Skip to content

Commit

Permalink
refactor!: use a configuration object
Browse files Browse the repository at this point in the history
  • Loading branch information
gadomski committed Aug 7, 2023
1 parent d5dc861 commit 8a96b21
Show file tree
Hide file tree
Showing 10 changed files with 168 additions and 153 deletions.
6 changes: 4 additions & 2 deletions src/stac_asset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@


from .client import Client
from .config import Config
from .earthdata_client import EarthdataClient
from .errors import (
AssetDownloadError,
AssetOverwriteError,
CantIncludeAndExclude,
CannotIncludeAndExclude,
DownloadError,
DownloadWarning,
)
Expand All @@ -34,8 +35,9 @@
"DownloadWarning",
"AssetDownloadError",
"AssetOverwriteError",
"CantIncludeAndExclude",
"CannotIncludeAndExclude",
"Client",
"Config",
"DownloadError",
"EarthdataClient",
"FileNameStrategy",
Expand Down
24 changes: 13 additions & 11 deletions src/stac_asset/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import click
from pystac import Item, ItemCollection

from . import functions
from . import Config, functions


@click.group()
Expand All @@ -30,6 +30,7 @@ def cli() -> None:
help="Asset keys to exclude (can't be used with include)",
multiple=True,
)
@click.option("-f", "--file-name", help="The output file name")
@click.option(
"-q",
"--quiet",
Expand Down Expand Up @@ -58,6 +59,7 @@ def download(
directory: Optional[str],
include: List[str],
exclude: List[str],
file_name: Optional[str],
quiet: bool,
s3_requester_pays: bool,
warn: bool,
Expand Down Expand Up @@ -88,6 +90,14 @@ def download(
if directory is None:
directory = os.getcwd()

config = Config(
include=include,
exclude=exclude,
file_name=file_name,
s3_requester_pays=s3_requester_pays,
warn=warn,
)

type_ = input_dict.get("type")
if type_ is None:
print("ERROR: missing 'type' field on input dictionary", file=sys.stderr)
Expand All @@ -101,11 +111,7 @@ def download(
functions.download_item(
item,
directory,
include=include or None,
exclude=exclude or None,
item_file_name=None,
s3_requester_pays=s3_requester_pays,
warn_on_download_error=warn,
config=config,
)
)
elif type_ == "FeatureCollection":
Expand All @@ -114,11 +120,7 @@ def download(
functions.download_item_collection(
item_collection,
directory,
include=include or None,
exclude=exclude or None,
item_collection_file_name=None,
s3_requester_pays=s3_requester_pays,
warn_on_download_error=warn,
config=config,
)
)
else:
Expand Down
102 changes: 34 additions & 68 deletions src/stac_asset/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
from pystac import Asset, Item, ItemCollection
from yarl import URL

from .config import Config
from .errors import (
AssetDownloadError,
AssetOverwriteError,
CantIncludeAndExclude,
DownloadError,
DownloadWarning,
)
Expand Down Expand Up @@ -123,70 +123,52 @@ async def download_item(
self,
item: Item,
directory: PathLikeObject,
*,
make_directory: bool = False,
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
item_file_name: Optional[str] = "item.json",
include_self_link: bool = True,
asset_file_name_strategy: FileNameStrategy = FileNameStrategy.FILE_NAME,
warn_on_download_error: bool = False,
config: Optional[Config] = None,
) -> Item:
"""Downloads an item and all of its assets to the given directory.
Args:
item: The item to download
directory: The root location of the downloaded files
make_directory: If true and the directory doesn't exist, create the
output directory before downloading
include: Asset keys to download. If not provided, all asset keys
will be downloaded.
exclude: Asset keys to not download. If not provided, all asset keys
will be downloaded.
item_file_name: The name of the item file. If not provided, the item
will not be written to the filesystem (only the assets will be
downloaded).
include_self_link: Whether to include a self link on the item.
Unused if ``item_file_name=None``.
asset_file_name_strategy: The :py:class:`FileNameStrategy` to use
for naming asset files
warn_on_download_error: Instead of raising any errors encountered
while downloading, warn and delete the asset from the item
config: Configuration for downloading the item
Returns:
Item: The :py:class:`~pystac.Item`, with updated asset hrefs
Raises:
CantIncludeAndExclude: Raised if both include and exclude are not None.
"""
if include is not None and exclude is not None:
raise CantIncludeAndExclude()
if config is None:
config = Config()
else:
config.validate()

directory_as_path = Path(directory)
if not directory_as_path.exists():
if make_directory:
if config.make_directory:
directory_as_path.mkdir()
else:
raise FileNotFoundError(f"output directory does not exist: {directory}")

if item_file_name:
item_path = directory_as_path / item_file_name
if config.file_name:
item_path = directory_as_path / config.file_name
else:
item_path = None
self_href = item.get_self_href()
if self_href:
item_path = directory_as_path / os.path.basename(self_href)
else:
item_path = None

tasks: List[Task[Any]] = list()
file_names: Dict[str, str] = dict()
item.make_asset_hrefs_absolute()
for key, asset in (
(k, a)
for k, a in item.assets.items()
if (include is None or k in include)
and (exclude is None or k not in exclude)
if (not config.include or k in config.include)
and (not config.exclude or k not in config.exclude)
):
# TODO strategy should be auto-guessable
if asset_file_name_strategy == FileNameStrategy.FILE_NAME:
if config.asset_file_name_strategy == FileNameStrategy.FILE_NAME:
file_name = os.path.basename(URL(asset.href).path)
elif asset_file_name_strategy == FileNameStrategy.KEY:
elif config.asset_file_name_strategy == FileNameStrategy.KEY:
file_name = key + Path(asset.href).suffix
path = directory_as_path / file_name
if file_name in file_names:
Expand All @@ -212,7 +194,7 @@ async def download_item(
if isinstance(result, Exception):
exceptions.append(result)
if exceptions:
if warn_on_download_error:
if config.warn:
for exception in exceptions:
warnings.warn(str(exception), DownloadWarning)
if isinstance(exception, AssetDownloadError):
Expand All @@ -230,7 +212,7 @@ async def download_item(

if item_path:
item.set_self_href(str(item_path))
item.save_object(include_self_link=include_self_link)
item.save_object(include_self_link=True)
else:
item.set_self_href(None)

Expand All @@ -240,32 +222,14 @@ async def download_item_collection(
self,
item_collection: ItemCollection,
directory: PathLikeObject,
*,
make_directory: bool = False,
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
item_collection_file_name: Optional[str] = "item-collection.json",
asset_file_name_strategy: FileNameStrategy = FileNameStrategy.FILE_NAME,
warn_on_download_error: bool = False,
config: Optional[Config] = None,
) -> ItemCollection:
"""Downloads an item collection and all of its assets to the given directory.
Args:
item_collection: The item collection to download
directory: The root location of the downloaded files
make_directory: If true and and the directory does not exist, create
the output directory before downloading
include: Asset keys to download. If not provided, all asset keys
will be downloaded.
exclude: Asset keys to not download. If not provided, all asset keys
will be downloaded.
item_collection_file_name: The name of the item collection file in the
directory. If not provided, the item collection will not be
written to the filesystem (only the assets will be downloaded).
asset_file_name_strategy: The :py:class:`FileNameStrategy` to use
for naming asset files
warn_on_download_error: Instead of raising any errors encountered
while downloading, warn and delete the asset from the item
config: Configuration for downloading the item
Returns:
ItemCollection: The :py:class:`~pystac.ItemCollection`, with the
Expand All @@ -274,9 +238,13 @@ async def download_item_collection(
Raises:
CannotIncludeAndExclude: Raised if both include and exclude are non-empty.
"""
if config is None:
config = Config()
# Config validation happens at the download_item level

directory_as_path = Path(directory)
if not directory_as_path.exists():
if make_directory:
if config.make_directory:
directory_as_path.mkdir()
else:
raise FileNotFoundError(f"output directory does not exist: {directory}")
Expand All @@ -285,17 +253,15 @@ async def download_item_collection(
for item in item_collection.items:
# TODO what happens if items share ids?
item_directory = directory_as_path / item.id
item_config = config.copy()
item_config.make_directory = True
item_config.file_name = None
tasks.append(
asyncio.create_task(
self.download_item(
item=item,
directory=item_directory,
make_directory=True,
include=include,
exclude=exclude,
item_file_name=None,
asset_file_name_strategy=asset_file_name_strategy,
warn_on_download_error=warn_on_download_error,
config=item_config,
)
)
)
Expand All @@ -307,9 +273,9 @@ async def download_item_collection(
if exceptions:
raise DownloadError(exceptions)
item_collection.items = results
if item_collection_file_name:
if config.file_name:
item_collection.save_object(
dest_href=str(directory_as_path / item_collection_file_name)
dest_href=str(directory_as_path / config.file_name)
)
return item_collection

Expand Down
63 changes: 63 additions & 0 deletions src/stac_asset/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from __future__ import annotations

import copy
from dataclasses import dataclass, field
from typing import List, Optional

from .errors import CannotIncludeAndExclude
from .strategy import FileNameStrategy


@dataclass
class Config:
    """Options controlling how items and their assets are downloaded."""

    asset_file_name_strategy: FileNameStrategy = FileNameStrategy.FILE_NAME
    """Strategy used to derive the file name of each downloaded asset."""

    exclude: List[str] = field(default_factory=list)
    """Asset keys to leave out of the download.

    Cannot be combined with ``include``.
    """

    include: List[str] = field(default_factory=list)
    """Asset keys to download.

    Cannot be combined with ``exclude``.
    """

    file_name: Optional[str] = None
    """The file name of the output item or item collection.

    If not provided, an item collection is not written to disk; an item is
    written only when it has a self href, whose base name is then used.
    """

    make_directory: bool = True
    """Whether to create the output directory when it does not exist.

    If False and the output directory is missing, an error is raised.
    """

    warn: bool = False
    """If True, download errors produce warnings instead of being raised."""

    s3_requester_pays: bool = False
    """If using the s3 client, enable requester pays."""

    def validate(self) -> None:
        """Check this configuration for conflicting options.

        Raises:
            CannotIncludeAndExclude: Both ``include`` and ``exclude`` are
                non-empty; the two options are mutually exclusive.
        """
        both_given = bool(self.include) and bool(self.exclude)
        if not both_given:
            return
        raise CannotIncludeAndExclude(include=self.include, exclude=self.exclude)

    def copy(self) -> Config:
        """Create an independent deep copy of this configuration.

        Returns:
            Config: A new config whose fields (including the ``include`` and
            ``exclude`` lists) are not shared with this one.
        """
        duplicate = copy.deepcopy(self)
        return duplicate
11 changes: 8 additions & 3 deletions src/stac_asset/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,17 @@ class DownloadWarning(Warning):
"""


class CantIncludeAndExclude(Exception):
class CannotIncludeAndExclude(Exception):
"""Raised if both include and exclude are passed to download."""

def __init__(self, *args: Any, **kwargs: Any) -> None:
def __init__(
self, include: List[str], exclude: List[str], *args: Any, **kwargs: Any
) -> None:
super().__init__(
"can't use include and exclude in the same download call", *args, **kwargs
"can't use include and exclude in the same download call: "
f"include={include}, exclude={exclude}",
*args,
**kwargs,
)


Expand Down
Loading

0 comments on commit 8a96b21

Please sign in to comment.