From e1b6187771bbb1a2413624ef9a43d46c82bbf17f Mon Sep 17 00:00:00 2001 From: wietzesuijker Date: Tue, 7 Jan 2025 17:17:09 +0000 Subject: [PATCH] Feat: Refactor upload method and reduce duplication - Restored type hints in `upload` for improved type safety. - Added `get_mpu_kwargs` to centralize shared keyword arguments. - Simplified `upload` and `mpu_upload` implementations by reusing `get_mpu_kwargs`. - Reduced code duplication across `_mpu.py` and `_multipart.py`. --- odc/geo/cog/_az.py | 44 ++++--------------------------------- odc/geo/cog/_mpu.py | 46 +++++++++++++++++++++++++++++++++++++++ odc/geo/cog/_multipart.py | 36 ++++++++++++++++++------------ odc/geo/cog/_s3.py | 29 ++++-------------------- odc/geo/geom.py | 5 +++-- 5 files changed, 79 insertions(+), 81 deletions(-) diff --git a/odc/geo/cog/_az.py b/odc/geo/cog/_az.py index d0f91a4b..78c5e212 100644 --- a/odc/geo/cog/_az.py +++ b/odc/geo/cog/_az.py @@ -1,10 +1,6 @@ import base64 -from typing import Any, Union +from typing import Any -import dask -from dask.delayed import Delayed - -from ._mpu import mpu_write from ._multipart import MultiPartUploadBase @@ -132,41 +128,9 @@ def writer(self, kw: dict[str, Any], *, client: Any = None): """ return DelayedAzureWriter(self, kw) - def upload( - self, - chunks: Union[dask.bag.Bag, list[dask.bag.Bag]], - *, - mk_header: Any = None, - mk_footer: Any = None, - user_kw: dict[str, Any] | None = None, - writes_per_chunk: int = 1, - spill_sz: int = 20 * (1 << 20), - client: Any = None, - **kw, - ) -> Delayed: - """ - Upload chunks to Azure Blob Storage with multipart uploads. - - :param chunks: Dask bag of chunks to upload. - :param mk_header: Function to create header data. - :param mk_footer: Function to create footer data. - :param user_kw: User-provided metadata for the upload. - :param writes_per_chunk: Number of writes per chunk. - :param spill_sz: Spill size for buffering data. - :param client: Dask client for distributed execution. - :return: A Dask delayed object representing the finalised upload. - """ - write = self.writer(kw, client=client) if spill_sz else None - return mpu_write( - chunks, - write, - mk_header=mk_header, - mk_footer=mk_footer, - user_kw=user_kw, - writes_per_chunk=writes_per_chunk, - spill_sz=spill_sz, - dask_name_prefix="azure-finalise", - ) + def dask_name_prefix(self) -> str: + """Return the Dask name prefix for Azure.""" + return "azure-finalise" class DelayedAzureWriter(AzureLimits): diff --git a/odc/geo/cog/_mpu.py b/odc/geo/cog/_mpu.py index f1776d8d..ddc453a5 100644 --- a/odc/geo/cog/_mpu.py +++ b/odc/geo/cog/_mpu.py @@ -495,3 +495,49 @@ def _finalizer_dask_op( _, rr = _root.flush(write, leftPartId=1, finalise=True) return rr + + +def get_mpu_kwargs( + mk_header=None, + mk_footer=None, + user_kw=None, + writes_per_chunk=1, + spill_sz=20 * (1 << 20), + client=None, +) -> dict: + """ + Construct shared keyword arguments for multipart uploads. + """ + return { + "mk_header": mk_header, + "mk_footer": mk_footer, + "user_kw": user_kw, + "writes_per_chunk": writes_per_chunk, + "spill_sz": spill_sz, + "client": client, + } + + +def mpu_upload( + chunks: Union[dask.bag.Bag, list[dask.bag.Bag]], + *, + writer: Any, + dask_name_prefix: str, + **kw, +) -> "Delayed": + """Shared logic for multipart uploads to storage services.""" + client = kw.pop("client", None) + writer_kw = dict(kw) + if client is not None: + writer_kw["client"] = client + spill_sz = kw.get("spill_sz", 20 * (1 << 20)) + if spill_sz: + write = writer(writer_kw) + else: + write = None + return mpu_write( + chunks, + write, + dask_name_prefix=dask_name_prefix, + **kw, # everything else remains + ) diff --git a/odc/geo/cog/_multipart.py b/odc/geo/cog/_multipart.py index 0fc9b4c8..c7376bfc 100644 --- a/odc/geo/cog/_multipart.py +++ b/odc/geo/cog/_multipart.py @@ -9,6 +9,9 @@ from abc import ABC, abstractmethod from typing import Any, Union, TYPE_CHECKING +from dask.delayed import Delayed +from ._mpu import get_mpu_kwargs, mpu_upload + if TYPE_CHECKING: # pylint: disable=import-outside-toplevel,import-error import dask.bag @@ -53,6 +56,9 @@ def writer(self, kw: dict[str, Any], *, client: Any = None) -> Any: """ @abstractmethod + def dask_name_prefix(self) -> str: + """Return the dask name prefix specific to the backend.""" + def upload( self, chunks: Union["dask.bag.Bag", list["dask.bag.Bag"]], @@ -63,17 +69,19 @@ def upload( writes_per_chunk: int = 1, spill_sz: int = 20 * (1 << 20), client: Any = None, - **kw, - ) -> Any: - """ - Orchestrate the upload process with multipart uploads. - - :param chunks: Dask bag of chunks to upload. - :param mk_header: Function to create header data. - :param mk_footer: Function to create footer data. - :param user_kw: User-provided metadata for the upload. - :param writes_per_chunk: Number of writes per chunk. - :param spill_sz: Spill size for buffering data. - :param client: Dask client for distributed execution. - :return: A Dask delayed object representing the finalised upload. - """ + ) -> Delayed: + """High-level upload that calls mpu_upload under the hood.""" + kwargs = get_mpu_kwargs( + mk_header=mk_header, + mk_footer=mk_footer, + user_kw=user_kw, + writes_per_chunk=writes_per_chunk, + spill_sz=spill_sz, + client=client, + ) + return mpu_upload( + chunks, + writer=self.writer, + dask_name_prefix=self.dask_name_prefix(), + **kwargs, + ) diff --git a/odc/geo/cog/_s3.py b/odc/geo/cog/_s3.py index 367b31c9..0ae6981f 100644 --- a/odc/geo/cog/_s3.py +++ b/odc/geo/cog/_s3.py @@ -9,7 +9,7 @@ from cachetools import cached -from ._mpu import PartsWriter, SomeData, mpu_write +from ._mpu import PartsWriter, SomeData from ._multipart import MultiPartUploadBase if TYPE_CHECKING: @@ -197,30 +197,9 @@ def writer(self, kw, *, client: Any = None) -> PartsWriter: writer.prep_client(client) return writer - def upload( - self, - chunks: "dask.bag.Bag" | list["dask.bag.Bag"], - *, - mk_header: Any = None, - mk_footer: Any = None, - user_kw: dict[str, Any] | None = None, - writes_per_chunk: int = 1, - spill_sz: int = 20 * (1 << 20), - client: Any = None, - **kw, - ) -> "Delayed": - """Upload chunks to S3 with multipart uploads.""" - write = self.writer(kw, client=client) if spill_sz else None - return mpu_write( - chunks, - write, - mk_header=mk_header, - mk_footer=mk_footer, - user_kw=user_kw, - writes_per_chunk=writes_per_chunk, - spill_sz=spill_sz, - dask_name_prefix="s3finalise", - ) + def dask_name_prefix(self) -> str: + """Return the Dask name prefix for S3.""" + return "s3finalise" def _safe_get(v, timeout=0.1): diff --git a/odc/geo/geom.py b/odc/geo/geom.py index 6aff6ef4..f1e89850 100644 --- a/odc/geo/geom.py +++ b/odc/geo/geom.py @@ -320,7 +320,6 @@ def boundary(self, pts_per_side: int = 2) -> "Geometry": self.crs, ) - def qr2sample( self, n: int, @@ -358,7 +357,8 @@ def qr2sample( n_side = int(numpy.round(sample_density * min(nx, ny))) + 1 n_side = max(2, n_side) edge_pts = [ - (float(ep[0]), float(ep[1])) for ep in list(self.boundary(n_side).coords[:-1]) + (float(ep[0]), float(ep[1])) + for ep in list(self.boundary(n_side).coords[:-1]) ] if padding is None: padding = 0.3 * min(nx, ny) / (n_side - 1) @@ -377,6 +377,7 @@ def qr2sample( return multipoint(coords, self.crs) + def wrap_shapely(method): """ Takes a method that expects shapely geometry arguments and converts it to a method that operates