Feat: Refactor upload method and reduce duplication
- Restored type hints in `upload` for improved type safety.
- Added `get_mpu_kwargs` to centralize shared keyword arguments.
- Simplified `upload` and `mpu_upload` implementations by reusing `get_mpu_kwargs`.
- Reduced code duplication across `_mpu.py` and `_multipart.py`.
wietzesuijker committed Jan 7, 2025
1 parent 3b0223c commit e1b6187
Showing 5 changed files with 79 additions and 81 deletions.
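To make the shape of the refactor concrete before the per-file diffs, here is a minimal, self-contained sketch of the pattern the commit moves to: the shared `upload` flow lives once in a base class, and each backend only supplies a writer factory and a Dask name prefix. Everything below is schematic stand-in code under those assumptions, not the library's actual classes or signatures.

```python
# Schematic stand-in for the refactor's shape; not odc.geo code.
from abc import ABC, abstractmethod
from typing import Any


def get_mpu_kwargs(**overrides) -> dict:
    """One place for the keyword arguments every backend used to repeat."""
    kwargs = {
        "mk_header": None,
        "mk_footer": None,
        "user_kw": None,
        "writes_per_chunk": 1,
        "spill_sz": 20 * (1 << 20),  # 20 MiB default spill size
        "client": None,
    }
    kwargs.update(overrides)
    return kwargs


class UploadBase(ABC):
    @abstractmethod
    def writer(self, kw: dict[str, Any]) -> Any:
        """Return a parts writer bound to the backend's client."""

    @abstractmethod
    def dask_name_prefix(self) -> str:
        """Label used for the finalise task in the Dask graph."""

    def upload(self, chunks, **kw) -> str:
        # Shared flow: build the common kwargs once, then hand off to the
        # backend-agnostic driver (mpu_upload/mpu_write in the real code).
        kwargs = get_mpu_kwargs(**kw)
        return f"{self.dask_name_prefix()} <- {sorted(kwargs)}"


class FakeS3Upload(UploadBase):
    def writer(self, kw: dict[str, Any]) -> Any:
        return ("s3-writer", kw)

    def dask_name_prefix(self) -> str:
        return "s3finalise"


print(FakeS3Upload().upload(["chunk-0"], writes_per_chunk=2))
```

In the actual commit, the shared `upload` lives on `MultiPartUploadBase` and hands off to the new `mpu_upload` helper, which in turn calls `mpu_write`, as the diffs below show.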
44 changes: 4 additions & 40 deletions odc/geo/cog/_az.py
@@ -1,10 +1,6 @@
import base64
from typing import Any, Union
from typing import Any

import dask
from dask.delayed import Delayed

from ._mpu import mpu_write
from ._multipart import MultiPartUploadBase


@@ -132,41 +128,9 @@ def writer(self, kw: dict[str, Any], *, client: Any = None):
"""
return DelayedAzureWriter(self, kw)

def upload(
self,
chunks: Union[dask.bag.Bag, list[dask.bag.Bag]],
*,
mk_header: Any = None,
mk_footer: Any = None,
user_kw: dict[str, Any] | None = None,
writes_per_chunk: int = 1,
spill_sz: int = 20 * (1 << 20),
client: Any = None,
**kw,
) -> Delayed:
"""
Upload chunks to Azure Blob Storage with multipart uploads.
:param chunks: Dask bag of chunks to upload.
:param mk_header: Function to create header data.
:param mk_footer: Function to create footer data.
:param user_kw: User-provided metadata for the upload.
:param writes_per_chunk: Number of writes per chunk.
:param spill_sz: Spill size for buffering data.
:param client: Dask client for distributed execution.
:return: A Dask delayed object representing the finalised upload.
"""
write = self.writer(kw, client=client) if spill_sz else None
return mpu_write(
chunks,
write,
mk_header=mk_header,
mk_footer=mk_footer,
user_kw=user_kw,
writes_per_chunk=writes_per_chunk,
spill_sz=spill_sz,
dask_name_prefix="azure-finalise",
)
def dask_name_prefix(self) -> str:
"""Return the Dask name prefix for Azure."""
return "azure-finalise"


class DelayedAzureWriter(AzureLimits):
46 changes: 46 additions & 0 deletions odc/geo/cog/_mpu.py
@@ -495,3 +495,49 @@ def _finalizer_dask_op(

_, rr = _root.flush(write, leftPartId=1, finalise=True)
return rr


def get_mpu_kwargs(
mk_header=None,
mk_footer=None,
user_kw=None,
writes_per_chunk=1,
spill_sz=20 * (1 << 20),
client=None,
) -> dict:
"""
Construct shared keyword arguments for multipart uploads.
"""
return {
"mk_header": mk_header,
"mk_footer": mk_footer,
"user_kw": user_kw,
"writes_per_chunk": writes_per_chunk,
"spill_sz": spill_sz,
"client": client,
}


def mpu_upload(
chunks: Union[dask.bag.Bag, list[dask.bag.Bag]],
*,
writer: Any,
dask_name_prefix: str,
**kw,
) -> "Delayed":
"""Shared logic for multipart uploads to storage services."""
client = kw.pop("client", None)
writer_kw = dict(kw)
if client is not None:
writer_kw["client"] = client
spill_sz = kw.get("spill_sz", 20 * (1 << 20))
if spill_sz:
write = writer(writer_kw)
else:
write = None
return mpu_write(
chunks,
write,
dask_name_prefix=dask_name_prefix,
**kw, # everything else remains
)
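A quick, self-contained illustration of the kwarg plumbing that the new `mpu_upload` introduces above: `client` is popped so it only reaches the writer factory (inside the `writer_kw` dict), and a falsy `spill_sz` skips creating a writer altogether. The `demo_mpu_upload`, `fake_writer`, and the calls below are illustrative only; the real function forwards the remaining kwargs to `mpu_write`.

```python
# Illustrative driver mirroring the plumbing of mpu_upload(); not library code.
def demo_mpu_upload(chunks, *, writer, dask_name_prefix, **kw):
    client = kw.pop("client", None)           # never forwarded to mpu_write
    writer_kw = dict(kw)
    if client is not None:
        writer_kw["client"] = client          # only the writer factory sees it
    spill_sz = kw.get("spill_sz", 20 * (1 << 20))
    write = writer(writer_kw) if spill_sz else None
    return {"write": write, "dask_name_prefix": dask_name_prefix, **kw}


def fake_writer(kw):
    return f"writer(client={kw.get('client')!r})"


with_writer = demo_mpu_upload(
    ["chunk-0"], writer=fake_writer, dask_name_prefix="demo-finalise",
    spill_sz=1 << 20, client="dask-client",
)
no_writer = demo_mpu_upload(
    ["chunk-0"], writer=fake_writer, dask_name_prefix="demo-finalise",
    spill_sz=0, client="dask-client",
)
assert with_writer["write"] == "writer(client='dask-client')"
assert no_writer["write"] is None and "client" not in no_writer
```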
36 changes: 22 additions & 14 deletions odc/geo/cog/_multipart.py
@@ -9,6 +9,9 @@
from abc import ABC, abstractmethod
from typing import Any, Union, TYPE_CHECKING

from dask.delayed import Delayed
from ._mpu import get_mpu_kwargs, mpu_upload

if TYPE_CHECKING:
# pylint: disable=import-outside-toplevel,import-error
import dask.bag
@@ -53,6 +56,9 @@ def writer(self, kw: dict[str, Any], *, client: Any = None) -> Any:
"""

@abstractmethod
def dask_name_prefix(self) -> str:
"""Return the dask name prefix specific to the backend."""

def upload(
self,
chunks: Union["dask.bag.Bag", list["dask.bag.Bag"]],
@@ -63,17 +69,19 @@ def upload(
writes_per_chunk: int = 1,
spill_sz: int = 20 * (1 << 20),
client: Any = None,
**kw,
) -> Any:
"""
Orchestrate the upload process with multipart uploads.
:param chunks: Dask bag of chunks to upload.
:param mk_header: Function to create header data.
:param mk_footer: Function to create footer data.
:param user_kw: User-provided metadata for the upload.
:param writes_per_chunk: Number of writes per chunk.
:param spill_sz: Spill size for buffering data.
:param client: Dask client for distributed execution.
:return: A Dask delayed object representing the finalised upload.
"""
) -> Delayed:
"""High-level upload that calls mpu_upload under the hood."""
kwargs = get_mpu_kwargs(
mk_header=mk_header,
mk_footer=mk_footer,
user_kw=user_kw,
writes_per_chunk=writes_per_chunk,
spill_sz=spill_sz,
client=client,
)
return mpu_upload(
chunks,
writer=self.writer,
dask_name_prefix=self.dask_name_prefix(),
**kwargs,
)
29 changes: 4 additions & 25 deletions odc/geo/cog/_s3.py
@@ -9,7 +9,7 @@

from cachetools import cached

from ._mpu import PartsWriter, SomeData, mpu_write
from ._mpu import PartsWriter, SomeData
from ._multipart import MultiPartUploadBase

if TYPE_CHECKING:
@@ -197,30 +197,9 @@ def writer(self, kw, *, client: Any = None) -> PartsWriter:
writer.prep_client(client)
return writer

def upload(
self,
chunks: "dask.bag.Bag" | list["dask.bag.Bag"],
*,
mk_header: Any = None,
mk_footer: Any = None,
user_kw: dict[str, Any] | None = None,
writes_per_chunk: int = 1,
spill_sz: int = 20 * (1 << 20),
client: Any = None,
**kw,
) -> "Delayed":
"""Upload chunks to S3 with multipart uploads."""
write = self.writer(kw, client=client) if spill_sz else None
return mpu_write(
chunks,
write,
mk_header=mk_header,
mk_footer=mk_footer,
user_kw=user_kw,
writes_per_chunk=writes_per_chunk,
spill_sz=spill_sz,
dask_name_prefix="s3finalise",
)
def dask_name_prefix(self) -> str:
"""Return the Dask name prefix for S3."""
return "s3finalise"


def _safe_get(v, timeout=0.1):
5 changes: 3 additions & 2 deletions odc/geo/geom.py
@@ -320,7 +320,6 @@ def boundary(self, pts_per_side: int = 2) -> "Geometry":
self.crs,
)


def qr2sample(
self,
n: int,
@@ -358,7 +357,8 @@ def qr2sample(
n_side = int(numpy.round(sample_density * min(nx, ny))) + 1
n_side = max(2, n_side)
edge_pts = [
(float(ep[0]), float(ep[1])) for ep in list(self.boundary(n_side).coords[:-1])
(float(ep[0]), float(ep[1]))
for ep in list(self.boundary(n_side).coords[:-1])
]
if padding is None:
padding = 0.3 * min(nx, ny) / (n_side - 1)
@@ -377,6 +377,7 @@

return multipoint(coords, self.crs)


def wrap_shapely(method):
"""
Takes a method that expects shapely geometry arguments and converts it to a method that operates
