From 3c9a0d94e1b316671c6d58ae96952d53087fd8ef Mon Sep 17 00:00:00 2001 From: adam Date: Fri, 5 Jan 2024 12:46:15 -0700 Subject: [PATCH 1/4] Ruff for all test files --- .pre-commit-config.yaml | 20 +-- giftless/app.py | 11 +- giftless/auth/__init__.py | 82 +++++------ giftless/auth/allow_anon.py | 27 ++-- giftless/auth/identity.py | 45 ++++--- giftless/auth/jwt.py | 127 ++++++++++-------- giftless/config.py | 16 ++- giftless/error_handling.py | 9 +- giftless/exc.py | 2 +- giftless/representation.py | 14 +- giftless/schema.py | 20 +-- giftless/storage/__init__.py | 30 +++-- giftless/storage/amazon_s3.py | 14 +- giftless/storage/azure.py | 104 ++++++++------ giftless/storage/exc.py | 13 +- giftless/storage/google_cloud.py | 41 +++--- giftless/storage/local_storage.py | 51 ++++--- giftless/transfer/__init__.py | 28 ++-- giftless/transfer/basic_external.py | 32 +++-- giftless/transfer/basic_streaming.py | 60 +++++---- giftless/transfer/multipart.py | 40 +++--- giftless/transfer/types.py | 21 +-- giftless/util.py | 23 ++-- giftless/view.py | 45 ++++--- giftless/wsgi_entrypoint.py | 4 +- pyproject.toml | 19 +-- tests/auth/test_auth.py | 27 ++-- tests/auth/test_jwt.py | 71 +++++----- tests/conftest.py | 18 +-- tests/helpers.py | 26 ++-- tests/storage/__init__.py | 56 +++++--- tests/storage/test_amazon_s3.py | 12 +- tests/storage/test_azure.py | 18 ++- tests/storage/test_google_cloud.py | 25 ++-- tests/storage/test_local.py | 38 +++--- tests/test_batch_api.py | 85 ++++++------ tests/test_error_responses.py | 11 +- tests/test_middleware.py | 14 +- tests/test_schema.py | 19 ++- tests/transfer/conftest.py | 9 +- tests/transfer/test_basic_external_adapter.py | 17 +-- tests/transfer/test_module.py | 10 +- 42 files changed, 733 insertions(+), 621 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 71bdb1b..aff3466 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,24 +9,16 @@ repos: # args: [--allow-multiple-documents] - id: 
trailing-whitespace - # FIXME: introduce after initial cleanup; it's going to take a lot - # of work. - # - repo: https://github.com/astral-sh/ruff-pre-commit - # rev: v0.1.8 - # hooks: - # - id: ruff - # args: [--fix, --exit-non-zero-on-fix] - # - id: ruff-format - - # FIXME: replace with ruff, eventually - - repo: https://github.com/psf/black - rev: 23.12.1 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.1.8 hooks: - - id: black + - id: ruff + args: [--fix, --exit-non-zero-on-fix] + - id: ruff-format - repo: https://github.com/adamchainz/blacken-docs rev: 1.16.0 hooks: - id: blacken-docs additional_dependencies: [black==23.12.1] - args: [-l, '79', -t, py310] + args: [-l, '79', -t, 'py310'] diff --git a/giftless/app.py b/giftless/app.py index 05ede4b..8a61481 100644 --- a/giftless/app.py +++ b/giftless/app.py @@ -1,5 +1,4 @@ -"""Main Flask application initialization code -""" +"""Main Flask application initialization code.""" import logging import os from typing import Any @@ -14,7 +13,7 @@ def init_app(app: Flask | None = None, additional_config: Any = None) -> Flask: - """Flask app initialization""" + """Flask app initialization.""" if app is None: app = Flask(__name__) @@ -48,7 +47,7 @@ def init_app(app: Flask | None = None, additional_config: Any = None) -> Flask: def _load_middleware(flask_app: Flask) -> None: - """Load WSGI middleware classes from configuration""" + """Load WSGI middleware classes from configuration.""" log = logging.getLogger(__name__) wsgi_app = flask_app.wsgi_app middleware_config = flask_app.config["MIDDLEWARE"] @@ -58,6 +57,6 @@ def _load_middleware(flask_app: Flask) -> None: args = spec.get("args", []) kwargs = spec.get("kwargs", {}) wsgi_app = klass(wsgi_app, *args, **kwargs) - log.debug("Loaded middleware: %s(*%s, **%s)", klass, args, kwargs) + log.debug(f"Loaded middleware: {klass}(*{args}, **{kwargs})") - flask_app.wsgi_app = wsgi_app # type: ignore + flask_app.wsgi_app = wsgi_app diff --git
a/giftless/auth/__init__.py b/giftless/auth/__init__.py index 5984148..2020c88 100644 --- a/giftless/auth/__init__.py +++ b/giftless/auth/__init__.py @@ -1,10 +1,9 @@ -"""Abstract authentication and authorization layer -""" +"""Abstract authentication and authorization layer.""" import abc import logging from collections.abc import Callable from functools import wraps -from typing import Any, Union +from typing import Any from flask import Flask, Request, current_app, g from flask import request as flask_request @@ -22,25 +21,29 @@ # Type for "Authenticator" # This can probably be made more specific once our protocol is more clear +# TODO @athornton: can it? class Authenticator(Protocol): - """Authenticators are callables (an object or function) that can authenticate - a request and provide an identity object + """Authenticators are callables (an object or function) that can + authenticate a request and provide an identity object. """ def __call__(self, request: Request) -> Identity | None: raise NotImplementedError( - "This is a protocol definition and should not be called directly" + "This is a protocol definition;" + " it should not be called directly." ) class PreAuthorizedActionAuthenticator(abc.ABC): """Pre-authorized action authenticators are special authenticators - that can also pre-authorize a follow-up action to the Git LFS server + that can also pre-authorize a follow-up action to the Git LFS + server. - They serve to both pre-authorize Git LFS actions and check these actions - are authorized as they come in. + They serve to both pre-authorize Git LFS actions and check these + actions are authorized as they come in. 
""" + @abc.abstractmethod def get_authz_query_params( self, identity: Identity, @@ -50,9 +53,9 @@ def get_authz_query_params( oid: str | None = None, lifetime: int | None = None, ) -> dict[str, str]: - """Authorize an action by adding credientaisl to the query string""" - return {} + """Authorize an action by adding credientaisl to the query string.""" + @abc.abstractmethod def get_authz_header( self, identity: Identity, @@ -62,11 +65,14 @@ def get_authz_header( oid: str | None = None, lifetime: int | None = None, ) -> dict[str, str]: - """Authorize an action by adding credentials to the request headers""" - return {} + """Authorize an action by adding credentials to the request headers.""" class Authentication: + """Wrap multiple Authenticators and default behaviors into an object to + manage authentication flow. + """ + def __init__( self, app: Flask | None = None, @@ -81,7 +87,7 @@ def __init__( self.init_app(app) def init_app(self, app: Flask) -> None: - """Initialize the Flask app""" + """Initialize the Flask app.""" app.config.setdefault("AUTH_PROVIDERS", []) app.config.setdefault("PRE_AUTHORIZED_ACTION_PROVIDER", None) @@ -99,7 +105,7 @@ def get_identity(self) -> Identity | None: return None def login_required(self, f: Callable) -> Callable: - """A typical Flask "login_required" view decorator""" + """Decorate the view; a typical Flask "login_required".""" @wraps(f) def decorated_function(*args: Any, **kwargs: Any) -> Any: @@ -111,10 +117,10 @@ def decorated_function(*args: Any, **kwargs: Any) -> Any: return decorated_function def no_identity_handler(self, f: Callable) -> Callable: - """Marker decorator for "unauthorized handler" function + """Marker decorator for "unauthorized handler" function. - This function will be called automatically if no authenticated identity was found - but is required. + This function will be called automatically if no authenticated + identity was found but is required. 
""" self._unauthorized_handler = f @@ -125,14 +131,14 @@ def decorated_func(*args: Any, **kwargs: Any) -> Any: return decorated_func def auth_failure(self) -> Any: - """Trigger an authentication failure""" + """Trigger an authentication failure.""" if self._unauthorized_handler: return self._unauthorized_handler() else: raise Unauthorized("User identity is required") def init_authenticators(self, reload: bool = False) -> None: - """Register an authenticator function""" + """Register an authenticator function.""" if reload: self._authenticators = None @@ -141,8 +147,9 @@ def init_authenticators(self, reload: bool = False) -> None: log = logging.getLogger(__name__) log.debug( - "Initializing authenticators, have %d authenticator(s) configured", - len(current_app.config["AUTH_PROVIDERS"]), + "Initializing authenticators," + f" have {len(current_app.config['AUTH_PROVIDERS'])}" + " authenticator(s) configured" ) self._authenticators = [ @@ -158,14 +165,14 @@ def init_authenticators(self, reload: bool = False) -> None: self.push_authenticator(self.preauth_handler) def push_authenticator(self, authenticator: Authenticator) -> None: - """Push an authenticator at the top of the stack""" + """Push an authenticator at the top of the stack.""" if self._authenticators is None: self._authenticators = [authenticator] return self._authenticators.insert(0, authenticator) def _authenticate(self) -> Identity | None: - """Call all registered authenticators until we find an identity""" + """Call all registered authenticators until we find an identity.""" self.init_authenticators() if self._authenticators is None: return self._default_identity @@ -174,9 +181,9 @@ def _authenticate(self) -> Identity | None: current_identity = authn(flask_request) if current_identity is not None: return current_identity - except Unauthorized as e: - # An authenticator is telling us the provided identity is invalid - # We should stop looking and return "no identity" + except Unauthorized as e: # 
noqa:PERF203 + # An authenticator is telling us the provided identity is + # invalid, so we should stop looking and return "no identity" log = logging.getLogger(__name__) log.debug(e.description) return None @@ -184,24 +191,23 @@ def _authenticate(self) -> Identity | None: return self._default_identity -def _create_authenticator(spec: Union[str, dict[str, Any]]) -> Authenticator: - """Instantiate an authenticator from configuration spec +def _create_authenticator(spec: str | dict[str, Any]) -> Authenticator: + """Instantiate an authenticator from configuration spec. - Configuration spec can be a string referencing a callable (e.g. mypackage.mymodule:callable) - in which case the callable will be returned as is; Or, a dict with 'factory' and 'options' - keys, in which case the factory callable is called with 'options' passed in as argument, and - the resulting callable is returned. + Configuration spec can be a string referencing a callable + (e.g. mypackage.mymodule:callable) in which case the callable will + be returned as is; Or, a dict with 'factory' and 'options' keys, + in which case the factory callable is called with 'options' passed + in as argument, and the resulting callable is returned. 
""" log = logging.getLogger(__name__) if isinstance(spec, str): - log.debug("Creating authenticator: %s", spec) + log.debug(f"Creating authenticator: {spec}") return get_callable(spec, __name__) - log.debug("Creating authenticator using factory: %s", spec["factory"]) - factory = get_callable( - spec["factory"], __name__ - ) # type: Callable[..., Authenticator] + log.debug(f"Creating authenticator using factory: {spec['factory']}") + factory = get_callable(spec["factory"], __name__) options = spec.get("options", {}) return factory(**options) diff --git a/giftless/auth/allow_anon.py b/giftless/auth/allow_anon.py index 7415548..319a056 100644 --- a/giftless/auth/allow_anon.py +++ b/giftless/auth/allow_anon.py @@ -1,17 +1,20 @@ -"""Dummy authentication module +"""Dummy authentication module. Always returns an `AnonymousUser` identity object. -Depending on whether "read only" or "read write" authentication was used, the -user is going to have read-only or read-write permissions on all objects. +Depending on whether "read only" or "read write" authentication was +used, the user is going to have read-only or read-write permissions on +all objects. -Only use this in production if you want your Giftless server to allow anonymous -access. Most likely, this is not what you want unless you are deploying in a -closed, 100% trusted network. +Only use this in production if you want your Giftless server to allow +anonymous access. Most likely, this is not what you want unless you +are deploying in a closed, 100% trusted network, or your server is +behind a proxy that handles authentication for the services it +manages. -If for some reason you want to allow anonymous users as a fall back (e.g. you -want to allow read-only access to anyone), be sure to load this authenticator -last. +If for some reason you want to allow anonymous users as a fallback +(e.g. you want to allow read-only access to anyone), be sure to load +this authenticator last. 
""" from typing import Any @@ -19,7 +22,7 @@ class AnonymousUser(DefaultIdentity): - """An anonymous user object""" + """An anonymous user object.""" def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -28,14 +31,14 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: def read_only(_: Any) -> AnonymousUser: - """Dummy authenticator that gives read-only permissions to everyone""" + """Give read-only permissions to everyone via AnonymousUser.""" user = AnonymousUser() user.allow(permissions={Permission.READ, Permission.READ_META}) return user def read_write(_: Any) -> AnonymousUser: - """Dummy authenticator that gives full permissions to everyone""" + """Give full permissions to everyone via AnonymousUser.""" user = AnonymousUser() user.allow(permissions=Permission.all()) return user diff --git a/giftless/auth/identity.py b/giftless/auth/identity.py index 08c42a2..8c669e1 100644 --- a/giftless/auth/identity.py +++ b/giftless/auth/identity.py @@ -1,11 +1,11 @@ +"""Objects to support Giftless's concept of users and permissions.""" from abc import ABC, abstractmethod from collections import defaultdict from enum import Enum -from typing import Optional class Permission(Enum): - """System wide permissions""" + """System wide permissions.""" READ = "read" READ_META = "read-meta" @@ -17,20 +17,21 @@ def all(cls) -> set["Permission"]: PermissionTree = dict[ - Optional[str], dict[Optional[str], dict[Optional[str], set[Permission]]] + str | None, dict[str | None], dict[str | None, set[Permission]] ] class Identity(ABC): - """Base user identity object + """Base user identity object. - The goal of user objects is to contain some information about the user, and also - to allow checking if the user is authorized to perform some actions. + The goal of user objects is to contain some information about the + user, and also to allow checking if the user is authorized to + perform some actions. 
""" - name: Optional[str] = None - id: Optional[str] = None - email: Optional[str] = None + name: str | None = None + id: str | None = None + email: str | None = None @abstractmethod def is_authorized( @@ -38,21 +39,25 @@ def is_authorized( organization: str, repo: str, permission: Permission, - oid: Optional[str] = None, + oid: str | None = None, ) -> bool: - """Tell if user is authorized to perform an operation on an object / repo""" + """Determine whether user is authorized to perform an operation + on an object or repo. + """ def __repr__(self) -> str: return f"<{self.__class__.__name__} id:{self.id} name:{self.name}>" class DefaultIdentity(Identity): + """Default instantiable user identity class.""" + def __init__( self, - name: Optional[str] = None, - id: Optional[str] = None, - email: Optional[str] = None, - ): + name: str | None = None, + id: str | None = None, + email: str | None = None, + ) -> None: self.name = name self.id = id self.email = email @@ -62,10 +67,10 @@ def __init__( def allow( self, - organization: Optional[str] = None, - repo: Optional[str] = None, - permissions: Optional[set[Permission]] = None, - oid: Optional[str] = None, + organization: str | None = None, + repo: str | None = None, + permissions: set[Permission] | None = None, + oid: str | None = None, ) -> None: if permissions is None: self._allowed[organization][repo][oid] = set() @@ -77,7 +82,7 @@ def is_authorized( organization: str, repo: str, permission: Permission, - oid: Optional[str] = None, + oid: str | None = None, ) -> bool: if organization in self._allowed: if repo in self._allowed[organization]: diff --git a/giftless/auth/jwt.py b/giftless/auth/jwt.py index 5d9e33c..d200376 100644 --- a/giftless/auth/jwt.py +++ b/giftless/auth/jwt.py @@ -1,6 +1,8 @@ +"""Objects for JWT-based authentication.""" import logging from datetime import datetime, timedelta -from typing import Any, Optional, Union +from pathlib import Path +from typing import Any import jwt from dateutil.tz import 
UTC @@ -13,12 +15,13 @@ class JWTAuthenticator(PreAuthorizedActionAuthenticator): - """Default JWT based authenticator + """Default JWT based authenticator. - This authenticator authenticates users by accepting a well-formed JWT token - (in the Authorization header as a Bearer type token). Tokens must be signed - by the right key, and also match in terms of audience, issuer and key ID if - configured, and of course have valid course expiry / not before times. + This authenticator authenticates users by accepting a well-formed + JWT token (in the Authorization header as a Bearer type + token). Tokens must be signed by the right key, and also match in + terms of audience, issuer and key ID if configured, and of course + have valid course expiry / not before times. Beyond authentication, JWT tokens may also include authorization payload in the "scopes" claim. @@ -44,13 +47,13 @@ class JWTAuthenticator(PreAuthorizedActionAuthenticator): are granting access to all objects in the repository {subscope} can be 'metadata' or omitted entirely. If 'metadata' is - specified, the scope does not grant access to actual files, - but to metadata only - e.g. objects can be verified to exist - but not downloaded. + specified, the scope does not grant access to actual + files, but to metadata only - e.g. objects can be + verified to exist but not downloaded. - {actions} is a comma separated list of allowed actions. Actions can be - 'read', 'write' or 'verify'. If omitted or replaced with a - '*', all actions are permitted. + {actions} is a comma separated list of allowed actions. Actions + can be 'read', 'write' or 'verify'. If omitted or + replaced with a '*', all actions are permitted. 
Some examples of decoded tokens (added comments are not valid JSON): @@ -63,10 +66,12 @@ class JWTAuthenticator(PreAuthorizedActionAuthenticator): "email": "user@example.com", // Optional, user's email "scopes": [ // read a specific object - "obj:datopian/somerepo/6adada03e86b154be00e25f288fcadc27aef06c47f12f88e3e1985c502803d1b:read", + "obj:datopian/somerepo/6adada03e86b154be00e25f288f" + "cadc27aef06c47f12f88e3e1985c502803d1b:read", // read the same object, but do not limit to a specific prefix - "obj:6adada03e86b154be00e25f288fcadc27aef06c47f12f88e3e1985c502803d1b:read", + "obj:6adada03e86b154be00e25f288f" + "cadc27aef06c47f12f88e3e1985c502803d1b:read", // full access to all objects in a repo "obj:datopian/my-repo/*", @@ -82,15 +87,17 @@ class JWTAuthenticator(PreAuthorizedActionAuthenticator): Typically a token will include a single scope - but multiple scopes are allowed. - This authenticator will pass on the attempt to authenticate if no token was - provided, or it is not a JWT token, or if a key ID is configured and a - provided JWT token does not have the matching "kid" head claim (this allows - chaining multiple JWT authenticators if needed). + This authenticator will pass on the attempt to authenticate if no + token was provided, or it is not a JWT token, or if a key ID is + configured and a provided JWT token does not have the matching + "kid" head claim (this allows chaining multiple JWT authenticators + if needed). - However, if a matching but invalid token was provided, a 401 Unauthorized - response will be returned. "Invalid" means a token with audience or issuer - mismatch (if configured), an expiry time in the past or an "not before" - time in the future, or, of course, an invalid signature. + However, if a matching but invalid token was provided, a 401 + Unauthorized response will be returned. 
"Invalid" means a token + with audience or issuer mismatch (if configured), an expiry time + in the past or an "not before" time in the future, or, of course, + an invalid signature. The "leeway" parameter allows for providing a leeway / grace time to be considered when checking expiry times, to cover for clock skew between @@ -104,16 +111,16 @@ class JWTAuthenticator(PreAuthorizedActionAuthenticator): def __init__( self, - private_key: Optional[Union[str, bytes]] = None, + private_key: str | bytes | None = None, default_lifetime: int = DEFAULT_LIFETIME, algorithm: str = DEFAULT_ALGORITHM, - public_key: Optional[str] = None, - issuer: Optional[str] = None, - audience: Optional[str] = None, + public_key: str | None = None, + issuer: str | None = None, + audience: str | None = None, leeway: int = DEFAULT_LEEWAY, - key_id: Optional[str] = None, - basic_auth_user: Optional[str] = DEFAULT_BASIC_AUTH_USER, - ): + key_id: str | None = None, + basic_auth_user: str | None = DEFAULT_BASIC_AUTH_USER, + ) -> None: self.algorithm = algorithm self.default_lifetime = default_lifetime self.leeway = leeway @@ -123,7 +130,7 @@ def __init__( self.audience = audience self.key_id = key_id self.basic_auth_user = basic_auth_user - self._verification_key: Union[str, bytes, None] = None # lazy loaded + self._verification_key: str | bytes | None = None # lazy loaded self._log = logging.getLogger(__name__) def __call__(self, request: Request) -> Identity | None: @@ -146,11 +153,11 @@ def _generate_token_for_action( identity: Identity, org: str, repo: str, - actions: Optional[set[str]] = None, - oid: Optional[str] = None, - lifetime: Optional[int] = None, + actions: set[str] | None = None, + oid: str | None = None, + lifetime: int | None = None, ) -> str: - """Generate a JWT token authorizing the specific requested action""" + """Generate a JWT token authorizing the specific requested action.""" token_payload: dict[str, Any] = {"sub": identity.id} if self.issuer: token_payload["iss"] = 
self.issuer @@ -178,20 +185,23 @@ def _generate_token_for_action( def _generate_action_scopes( org: str, repo: str, - actions: Optional[set[str]] = None, - oid: Optional[str] = None, + actions: set[str] | None = None, + oid: str | None = None, ) -> str: - """Generate token scopes based on target object and actions""" + """Generate token scopes based on target object and actions.""" if oid is None: oid = "*" obj_id = f"{org}/{repo}/{oid}" return str(Scope("obj", obj_id, actions)) def _generate_token(self, **kwargs: Any) -> str: - """Generate a JWT token that can be used later to authenticate a request""" + """Generate a JWT token that can be used later to authenticate + a request. + """ if not self.private_key: raise ValueError( - "This authenticator is not configured to generate tokens; Set private_key to fix" + "This authenticator is not configured to generate tokens;" + " set private_key to fix" ) payload: dict[str, Any] = { @@ -227,7 +237,7 @@ def _generate_token(self, **kwargs: Any) -> str: return token.decode("ascii") def _authenticate(self, request: Request) -> dict[str, Any] | None: - """Authenticate a request""" + """Authenticate a request.""" token = self._get_token_from_headers(request) if token is None: token = self._get_token_from_qs(request) @@ -253,14 +263,16 @@ def _authenticate(self, request: Request) -> dict[str, Any] | None: except jwt.PyJWTError as e: raise Unauthorized( f"Expired or otherwise invalid JWT token ({e!s})" - ) + ) from None def _get_token_from_headers(self, request: Request) -> str | None: - """Extract JWT token from HTTP Authorization header + """Extract JWT token from HTTP Authorization header. - This will first try to obtain a Bearer token. If none is found but we have a 'Basic' Authorization header, - and basic auth JWT payload has not been disabled, and the provided username matches the configured JWT token - username, we will try to use the provided password as if it was a JWT token. 
+ This will first try to obtain a Bearer token. If none is found + but we have a 'Basic' Authorization header, and basic auth JWT + payload has not been disabled, and the provided username + matches the configured JWT token username, we will try to use + the provided password as if it were a JWT token. """ header = request.headers.get("Authorization") if not header: @@ -288,8 +300,8 @@ def _get_token_from_headers(self, request: Request) -> str | None: return None @staticmethod - def _get_token_from_qs(request: Request) -> Optional[str]: - """Get JWT token from the query string""" + def _get_token_from_qs(request: Request) -> str | None: + """Get JWT token from the query string.""" return request.args.get("jwt") def _get_identity(self, jwt_payload: dict[str, Any]) -> Identity: @@ -307,7 +319,9 @@ def _get_identity(self, jwt_payload: dict[str, Any]) -> Identity: return identity def _parse_scope(self, scope_str: str) -> dict[str, Any]: - """Parse a scope string and convert it into arguments for Identity.allow()""" + """Parse a scope string and convert it into arguments for + Identity.allow(). 
+ """ scope = Scope.from_string(scope_str) if scope.entity_type != "obj": return {} @@ -339,7 +353,7 @@ def _parse_scope(self, scope_str: str) -> dict[str, Any]: @staticmethod def _parse_scope_permissions(scope: "Scope") -> set[Permission]: - """Extract granted permissions from scope object""" + """Extract granted permissions from scope object.""" permissions_map = { "read": {Permission.READ, Permission.READ_META}, "write": {Permission.WRITE}, @@ -358,8 +372,8 @@ def _parse_scope_permissions(scope: "Scope") -> set[Permission]: return permissions - def _get_verification_key(self) -> Union[str, bytes]: - """Get the key used for token verification, based on algorithm""" + def _get_verification_key(self) -> str | bytes: + """Get the key used for token verification, based on algorithm.""" if self._verification_key is None: if self.algorithm[0:2] == "HS": self._verification_key = self.private_key @@ -368,14 +382,14 @@ def _get_verification_key(self) -> Union[str, bytes]: if self._verification_key is None: raise ValueError( - "No private or public key have been set, can't verify requests" + "No private or public key is set; cannot verify requests" ) return self._verification_key class Scope: - """Scope object""" + """Scope object.""" def __init__( self, @@ -393,7 +407,7 @@ def __repr__(self) -> str: return f"" def __str__(self) -> str: - """Convert scope to a string""" + """Convert scope to a string.""" parts = [self.entity_type] entity_ref = self.entity_ref if self.entity_ref != "*" else None subscobe = self.subscope if self.subscope != "*" else None @@ -420,7 +434,7 @@ def __str__(self) -> str: @classmethod def from_string(cls, scope_str: str) -> "Scope": - """Create a scope object from string""" + """Create a scope object from string.""" parts = scope_str.split(":") if len(parts) < 1: raise ValueError("Scope string should have at least 1 part") @@ -445,11 +459,12 @@ def _parse_actions(cls, actions_str: str) -> set[str]: def factory(**options: Any) -> JWTAuthenticator: 
+ """Build a JWT Authenticator from supplied options.""" for key_type in ("private_key", "public_key"): file_opt = f"{key_type}_file" try: if options[file_opt]: - with open(options[file_opt]) as f: + with Path(options[file_opt]).open() as f: options[key_type] = f.read() options.pop(file_opt) except KeyError: diff --git a/giftless/config.py b/giftless/config.py index 66bddd7..149716b 100644 --- a/giftless/config.py +++ b/giftless/config.py @@ -1,11 +1,11 @@ -"""Configuration handling helper functions and default configuration -""" +"""Configuration handling helper functions and default configuration.""" import os +from pathlib import Path from typing import Any import yaml from dotenv import load_dotenv -from figcan import Configuration, Extensible # type:ignore +from figcan import Configuration, Extensible from flask import Flask ENV_PREFIX = "GIFTLESS_" @@ -17,7 +17,9 @@ "factory": "giftless.transfer.basic_streaming:factory", "options": Extensible( { - "storage_class": "giftless.storage.local_storage:LocalStorage", + "storage_class": ( + "giftless.storage.local_storage:LocalStorage" + ), "storage_options": Extensible({"path": "lfs-storage"}), "action_lifetime": 900, } @@ -50,7 +52,7 @@ def configure(app: Flask, additional_config: dict | None = None) -> Flask: - """Configure a Flask app using Figcan managed configuration object""" + """Configure a Flask app using Figcan managed configuration object.""" config = _compose_config(additional_config) app.config.update(config) return app @@ -59,14 +61,14 @@ def configure(app: Flask, additional_config: dict | None = None) -> Flask: def _compose_config( additional_config: dict[str, Any] | None = None, ) -> Configuration: - """Compose configuration object from all available sources""" + """Compose configuration object from all available sources.""" config = Configuration(default_config) environ = dict( os.environ ) # Copy the environment as we're going to change it if environ.get(f"{ENV_PREFIX}CONFIG_FILE"): - with 
open(environ[f"{ENV_PREFIX}CONFIG_FILE"]) as f: + with Path(environ[f"{ENV_PREFIX}CONFIG_FILE"]).open() as f: config_from_file = yaml.safe_load(f) config.apply(config_from_file) environ.pop(f"{ENV_PREFIX}CONFIG_FILE") diff --git a/giftless/error_handling.py b/giftless/error_handling.py index c9852ad..91ee8d3 100644 --- a/giftless/error_handling.py +++ b/giftless/error_handling.py @@ -1,6 +1,7 @@ -"""Handle errors according to the Git LFS spec +"""Handle errors according to the Git LFS spec. -See https://github.com/git-lfs/git-lfs/blob/master/docs/api/batch.md#response-errors +See https://github.com/git-lfs/git-lfs/blob/master/docs\ +/api/batch.md#response-errors """ from flask import Flask, Response from werkzeug.exceptions import default_exceptions @@ -9,6 +10,8 @@ class ApiErrorHandler: + """Handler to send JSON response for errors.""" + def __init__(self, app: Flask | None = None) -> None: if app: self.init_app(app) @@ -19,7 +22,7 @@ def init_app(self, app: Flask) -> None: @classmethod def error_as_json(cls, ex: Exception) -> Response: - """Handle errors by returning a JSON response""" + """Handle errors by returning a JSON response.""" code = ex.code if hasattr(ex, "code") else 500 data = {"message": str(ex)} diff --git a/giftless/exc.py b/giftless/exc.py index 799e8d8..516144c 100644 --- a/giftless/exc.py +++ b/giftless/exc.py @@ -1,4 +1,4 @@ -"""Map Werkzueg exceptions to domain specific exceptions +"""Map Werkzueg exceptions to domain-specific exceptions. These exceptions should be used in all domain (non-Flask specific) code to avoid tying in to Flask / Werkzueg where it is not needed. diff --git a/giftless/representation.py b/giftless/representation.py index 4f21c59..dee90b0 100644 --- a/giftless/representation.py +++ b/giftless/representation.py @@ -1,9 +1,10 @@ -"""Representations define how to render a response for a given content-type +"""Representations define how to render a response for a given content-type. 
Most commonly this will convert data returned by views into JSON or a similar format. -See http://flask-classful.teracy.org/#adding-resource-representations-get-real-classy-and-put-on-a-top-hat +See http://flask-classful.teracy.org/\ +#adding-resource-representations-get-real-classy-and-put-on-a-top-hat """ import json from datetime import datetime @@ -14,9 +15,12 @@ GIT_LFS_MIME_TYPE = "application/vnd.git-lfs+json" +# TODO @athornton: this, like the schemas, seems like something Pydantic +# does really well, but probably a big job. + class CustomJsonEncoder(json.JSONEncoder): - """Custom JSON encoder that can support some additional required types""" + """Custom JSON encoder that supports some additional required types.""" def default(self, o: Any) -> Any: if isinstance(o, datetime): @@ -30,13 +34,13 @@ def output_json( headers: dict[str, str] | None = None, content_type: str = "application/json", ) -> Response: + """Set appropriate Content-Type header for JSON response.""" dumped = json.dumps(data, cls=CustomJsonEncoder) if headers: headers.update({"Content-Type": content_type}) else: headers = {"Content-Type": content_type} - response = make_response(dumped, code, headers) - return response + return make_response(dumped, code, headers) output_git_lfs_json = partial(output_json, content_type=GIT_LFS_MIME_TYPE) diff --git a/giftless/schema.py b/giftless/schema.py index ea09c91..0e06cd1 100644 --- a/giftless/schema.py +++ b/giftless/schema.py @@ -1,5 +1,4 @@ -"""Schema for Git LFS APIs -""" +"""Schema for Git LFS APIs.""" from enum import Enum from typing import Any @@ -10,22 +9,25 @@ ma = Marshmallow() +# TODO @athornton: probably a big job but it feels like this is what Pydantic +# is for. 
+ class Operation(Enum): - """Batch operations""" + """Batch operations.""" upload = "upload" download = "download" -class RefSchema(ma.Schema): # type: ignore - """ref field schema""" +class RefSchema(ma.Schema): + """ref field schema.""" name = fields.String(required=True) -class ObjectSchema(ma.Schema): # type: ignore - """object field schema""" +class ObjectSchema(ma.Schema): + """object field schema.""" oid = fields.String(required=True) size = fields.Integer(required=True, validate=validate.Range(min=0)) @@ -46,7 +48,9 @@ def set_extra_fields( return {"extra": extra, **rest} -class BatchRequest(ma.Schema): # type: ignore +class BatchRequest(ma.Schema): + """batch request schema.""" + operation = EnumField(Operation, required=True) transfers = fields.List(fields.String, required=False, missing=["basic"]) ref = fields.Nested(RefSchema, required=False) diff --git a/giftless/storage/__init__.py b/giftless/storage/__init__.py index f91afa3..5ddc492 100644 --- a/giftless/storage/__init__.py +++ b/giftless/storage/__init__.py @@ -1,27 +1,32 @@ +"""Storage base classes.""" import mimetypes from abc import ABC, abstractmethod from collections.abc import Iterable -from typing import Any, BinaryIO, Optional +from typing import Any, BinaryIO from . import exc +# TODO @athornton: Think about refactoring this; some deduplication of +# `verify_object`, at least. + class VerifiableStorage(ABC): - """A storage backend that supports object verification API + """A storage backend that supports object verification API. All streaming backends should be 'verifiable'. """ @abstractmethod def verify_object(self, prefix: str, oid: str, size: int) -> bool: - """Check that object exists and has the right size + """Check that object exists and has the right size. - This method should not throw an error if the object does not exist, but return False + This method should not throw an error if the object does not + exist, but return False. 
""" class StreamingStorage(VerifiableStorage, ABC): - """Interface for streaming storage adapters""" + """Interface for streaming storage adapters.""" @abstractmethod def get(self, prefix: str, oid: str) -> Iterable[bytes]: @@ -43,15 +48,15 @@ def get_mime_type(self, prefix: str, oid: str) -> str: return "application/octet-stream" def verify_object(self, prefix: str, oid: str, size: int) -> bool: - """Verify that an object exists""" + """Verify that an object exists and has the right size.""" try: return self.get_size(prefix, oid) == size - except exc.ObjectNotFound: + except exc.ObjectNotFoundError: return False class ExternalStorage(VerifiableStorage, ABC): - """Interface for streaming storage adapters""" + """Interface for streaming storage adapters.""" @abstractmethod def get_upload_action( @@ -84,13 +89,16 @@ def get_size(self, prefix: str, oid: str) -> int: pass def verify_object(self, prefix: str, oid: str, size: int) -> bool: + """Verify that object exists and has the correct size.""" try: return self.get_size(prefix, oid) == size - except exc.ObjectNotFound: + except exc.ObjectNotFoundError: return False class MultipartStorage(VerifiableStorage, ABC): + """Base class for storage that supports multipart uploads.""" + @abstractmethod def get_multipart_actions( self, @@ -123,11 +131,13 @@ def get_size(self, prefix: str, oid: str) -> int: pass def verify_object(self, prefix: str, oid: str, size: int) -> bool: + """Verify that object exists and has the correct size.""" try: return self.get_size(prefix, oid) == size - except exc.ObjectNotFound: + except exc.ObjectNotFoundError: return False def guess_mime_type_from_filename(filename: str) -> str | None: + """Based on the filename, guess what MIME type it is.""" return mimetypes.guess_type(filename)[0] diff --git a/giftless/storage/amazon_s3.py b/giftless/storage/amazon_s3.py index f96df6a..d3e95e4 100644 --- a/giftless/storage/amazon_s3.py +++ b/giftless/storage/amazon_s3.py @@ -1,3 +1,4 @@ +"""Amazon S3 
backend.""" import base64 import binascii import posixpath @@ -8,7 +9,7 @@ import botocore from giftless.storage import ExternalStorage, StreamingStorage -from giftless.storage.exc import ObjectNotFound +from giftless.storage.exc import ObjectNotFoundError from giftless.util import safe_filename @@ -29,7 +30,7 @@ def __init__( def get(self, prefix: str, oid: str) -> Iterable[bytes]: if not self.exists(prefix, oid): - raise ObjectNotFound() + raise ObjectNotFoundError result: Iterable[bytes] = self._s3_object(prefix, oid).get()["Body"] return result @@ -50,7 +51,7 @@ def upload_callback(size: int) -> None: def exists(self, prefix: str, oid: str) -> bool: try: self.get_size(prefix, oid) - except ObjectNotFound: + except ObjectNotFoundError: return False return True @@ -59,9 +60,8 @@ def get_size(self, prefix: str, oid: str) -> int: result: int = self._s3_object(prefix, oid).content_length except botocore.exceptions.ClientError as e: if e.response["Error"]["Code"] == "404": - raise ObjectNotFound() - else: - raise e + raise ObjectNotFoundError from None + raise return result def get_upload_action( @@ -135,7 +135,7 @@ def get_download_action( } def _get_blob_path(self, prefix: str, oid: str) -> str: - """Get the path to a blob in storage""" + """Get the path to a blob in storage.""" if not self.path_prefix: storage_prefix = "" elif self.path_prefix[0] == "/": diff --git a/giftless/storage/azure.py b/giftless/storage/azure.py index ca0a84d..40775f8 100644 --- a/giftless/storage/azure.py +++ b/giftless/storage/azure.py @@ -1,10 +1,10 @@ +"""Azure cloud storage backend.""" import base64 import logging import posixpath -from collections import namedtuple from collections.abc import Iterable from datetime import datetime, timedelta, timezone -from typing import IO, Any, Optional +from typing import IO, Any, NamedTuple from urllib.parse import urlencode from xml.sax.saxutils import escape as xml_escape @@ -23,9 +23,16 @@ guess_mime_type_from_filename, ) -from .exc import 
ObjectNotFound +from .exc import ObjectNotFoundError + + +class Block(NamedTuple): + """Convenience wrapper for Azure block.""" + + id: int + start: int + size: int -Block = namedtuple("Block", ["id", "start", "size"]) _log = logging.getLogger(__name__) @@ -41,7 +48,7 @@ def __init__( self, connection_string: str, container_name: str, - path_prefix: Optional[str] = None, + path_prefix: str | None = None, enable_content_digest: bool = True, **_: Any, ) -> None: @@ -60,7 +67,7 @@ def get(self, prefix: str, oid: str) -> Iterable[bytes]: try: return blob_client.download_blob().chunks() except ResourceNotFoundError: - raise ObjectNotFound("Object does not exist") + raise ObjectNotFoundError("Object does not exist") from None def put(self, prefix: str, oid: str, data_stream: IO[bytes]) -> int: blob_client = self.blob_svc_client.get_blob_client( @@ -73,9 +80,9 @@ def put(self, prefix: str, oid: str, data_stream: IO[bytes]) -> int: def exists(self, prefix: str, oid: str) -> bool: try: self.get_size(prefix, oid) - return True - except ObjectNotFound: + except ObjectNotFoundError: return False + return True def get_size(self, prefix: str, oid: str) -> int: try: @@ -84,9 +91,9 @@ def get_size(self, prefix: str, oid: str) -> int: blob=self._get_blob_path(prefix, oid), ) props = blob_client.get_blob_properties() - return props.size except ResourceNotFoundError: - raise ObjectNotFound("Object does not exist") + raise ObjectNotFoundError("Object does not exist") from None + return props.size def get_mime_type(self, prefix: str, oid: str) -> str: try: @@ -100,9 +107,9 @@ def get_mime_type(self, prefix: str, oid: str) -> str: ) if mime_type is None: return "application/octet-stream" - return str(mime_type) except ResourceNotFoundError: - raise ObjectNotFound("Object does not exist") + raise ObjectNotFoundError("Object does not exist") from None + return str(mime_type) def get_upload_action( self, @@ -110,7 +117,7 @@ def get_upload_action( oid: str, size: int, expires_in: int, - 
extra: Optional[dict[str, Any]] = None, + extra: dict[str, Any] | None = None, ) -> dict[str, Any]: filename = extra.get("filename") if extra else None headers = { @@ -142,7 +149,7 @@ def get_download_action( oid: str, size: int, expires_in: int, - extra: Optional[dict[str, Any]] = None, + extra: dict[str, Any] | None = None, ) -> dict[str, Any]: filename = extra.get("filename") if extra else None disposition = ( @@ -173,9 +180,9 @@ def get_multipart_actions( size: int, part_size: int, expires_in: int, - extra: Optional[dict[str, Any]] = None, + extra: dict[str, Any] | None = None, ) -> dict[str, Any]: - """Get actions for a multipart upload""" + """Get actions for a multipart upload.""" blocks = _calculate_blocks(size, part_size) uncommitted = self._get_uncommitted_blocks(prefix, oid, blocks) @@ -195,9 +202,8 @@ def get_multipart_actions( if b.id not in uncommitted ] _log.info( - "There are %d uncommitted blocks pre-uploaded; %d parts still need to be uploaded", - len(uncommitted), - len(parts), + f"There are {len(uncommitted)} uncommitted blocks pre-uploaded;" + f" {len(parts)} parts still need to be uploaded" ) commit_body = self._create_commit_body(blocks) reply: dict[str, Any] = { @@ -229,7 +235,7 @@ def get_multipart_actions( return reply def _get_blob_path(self, prefix: str, oid: str) -> str: - """Get the path to a blob in storage""" + """Get the path to a blob in storage.""" if not self.path_prefix: storage_prefix = "" elif self.path_prefix[0] == "/": @@ -243,8 +249,8 @@ def _get_signed_url( prefix: str, oid: str, expires_in: int, - filename: Optional[str] = None, - disposition: Optional[str] = None, + filename: str | None = None, + disposition: str | None = None, **permissions: bool, ) -> str: blob_name = self._get_blob_path(prefix, oid) @@ -277,44 +283,48 @@ def _get_signed_url( blob_name=blob_name, credential=sas_token, ) - return blob_client.url # type: ignore + return blob_client.url def _get_uncommitted_blocks( self, prefix: str, oid: str, blocks: 
list[Block] ) -> dict[int, int]: - """Get list of uncommitted blocks from the server""" + """Get list of uncommitted blocks from the server.""" blob_client = self.blob_svc_client.get_blob_client( container=self.container_name, blob=self._get_blob_path(prefix, oid), ) try: - committed_blocks, uncommitted_blocks = blob_client.get_block_list( - block_list_type="all" - ) + ( + committed_blocks, + uncommitted_blocks, + ) = blob_client.get_block_list(block_list_type="all") except ResourceNotFoundError: return {} if committed_blocks: _log.warning( - f"Committed blocks found for {oid}, this is unexpected state; restarting upload" + f"Unexpected state: Committed blocks found for {oid};" + " restarting upload" ) blob_client.delete_blob() return {} try: - # NOTE: The Azure python library already does ID base64 decoding for us, so we only case to int here + # NOTE: The Azure python library already does ID base64 + # decoding for us, so we only cast to int here existing_blocks = { int(b["id"]): b["size"] for b in uncommitted_blocks } except ValueError: _log.warning( - "Some uncommitted blocks have unexpected ID format; restarting upload" + "Some uncommitted blocks have unexpected ID format;" + " restarting upload" ) return {} _log.debug( - "Found %d existing uncommitted blocks on server", - len(existing_blocks), + f"Found {len(existing_blocks)} existing uncommitted blocks" + " on server" ) # Verify that existing blocks are the same as what we plan to upload @@ -324,7 +334,8 @@ def _get_uncommitted_blocks( and existing_blocks[block.id] != block.size ): _log.warning( - "Uncommitted block size does not match our plan, restating upload" + "Uncommitted block size does not match our plan;" + " restarting upload" ) blob_client.delete_blob() return {} @@ -334,7 +345,7 @@ def _get_uncommitted_blocks( def _create_part_request( self, base_url: str, block: Block, expires_in: int ) -> dict[str, Any]: - """Create the part request 
object for a block.""" block_id = self._encode_block_id(block.id) part = { "href": f"{base_url}&comp=block&blockid={block_id}", @@ -349,12 +360,17 @@ def _create_part_request( return part def _create_commit_body(self, blocks: list[Block]) -> str: - """Create the body for a 'Put Blocks' request we use in commit + """Create the body for a 'Put Blocks' request we use in commit. - NOTE: This is a simple XML construct, so we don't import / depend on XML construction API - here. If this ever gets complex, it may be a good idea to rely on lxml or similar. + NOTE: This is a simple XML construct, so we don't import / + depend on XML construction API here. If this ever gets + complex, it may be a good idea to rely on lxml or similar. """ - return '{}'.format( + tpl = ( + '{}' + "" + ) + return tpl.format( "".join( [ "{}".format( @@ -367,20 +383,24 @@ def _create_commit_body(self, blocks: list[Block]) -> str: @classmethod def _encode_block_id(cls, b_id: int) -> str: - """Encode a block ID in the manner expected by the Azure API""" + """Encode a block ID in the manner expected by the Azure API.""" return base64.b64encode( str(b_id).zfill(cls._PART_ID_BYTE_SIZE).encode("ascii") ).decode("ascii") def _calculate_blocks(file_size: int, part_size: int) -> list[Block]: - """Calculate the list of blocks in a blob + """Calculate the list of blocks in a blob. 
>>> _calculate_blocks(30, 10) - [Block(id=0, start=0, size=10), Block(id=1, start=10, size=10), Block(id=2, start=20, size=10)] + [Block(id=0, start=0, size=10), + Block(id=1, start=10, size=10), + Block(id=2, start=20, size=10)] >>> _calculate_blocks(28, 10) - [Block(id=0, start=0, size=10), Block(id=1, start=10, size=10), Block(id=2, start=20, size=8)] + [Block(id=0, start=0, size=10), + Block(id=1, start=10, size=10), + Block(id=2, start=20, size=8)] >>> _calculate_blocks(7, 10) [Block(id=0, start=0, size=7)] diff --git a/giftless/storage/exc.py b/giftless/storage/exc.py index ef807cb..d268bd0 100644 --- a/giftless/storage/exc.py +++ b/giftless/storage/exc.py @@ -1,9 +1,8 @@ -"""Storage related errors -""" +"""Storage related errors.""" class StorageError(RuntimeError): - """Base class for storage errors""" + """Base class for storage errors.""" code: int | None = None @@ -11,9 +10,13 @@ def as_dict(self) -> dict[str, str | int | None]: return {"message": str(self), "code": self.code} -class ObjectNotFound(StorageError): +class ObjectNotFoundError(StorageError): + """No such object exists.""" + code = 404 -class InvalidObject(StorageError): +class InvalidObjectError(StorageError): + """Request is syntactically OK, but invalid (wrong fields, usually).""" + code = 422 diff --git a/giftless/storage/google_cloud.py b/giftless/storage/google_cloud.py index 5c55889..bb25445 100644 --- a/giftless/storage/google_cloud.py +++ b/giftless/storage/google_cloud.py @@ -1,9 +1,12 @@ +"""Google Cloud Storage backend supporting direct-to-cloud transfers via +signed URLs. 
+""" import base64 import io import json import posixpath from datetime import timedelta -from typing import Any, BinaryIO, Union +from typing import Any, BinaryIO import google.auth from google.auth import impersonated_credentials @@ -12,12 +15,12 @@ from giftless.storage import ExternalStorage, StreamingStorage -from .exc import ObjectNotFound +from .exc import ObjectNotFoundError class GoogleCloudStorage(StreamingStorage, ExternalStorage): """Google Cloud Storage backend supporting direct-to-cloud - transfers. + transfers via signed URLs. """ def __init__( @@ -32,11 +35,11 @@ def __init__( ) -> None: self.bucket_name = bucket_name self.path_prefix = path_prefix - self.credentials: Union[ - service_account.Credentials, - impersonated_credentials.Credentials, - None, - ] = self._load_credentials(account_key_file, account_key_base64) + self.credentials: ( + service_account.Credentials + | impersonated_credentials.Credentials + | None + ) = self._load_credentials(account_key_file, account_key_base64) self.storage_client = storage.Client( project=project_name, credentials=self.credentials ) @@ -52,7 +55,7 @@ def get(self, prefix: str, oid: str) -> BinaryIO: bucket = self.storage_client.bucket(self.bucket_name) blob = bucket.get_blob(self._get_blob_path(prefix, oid)) if blob is None: - raise ObjectNotFound("Object does not exist") + raise ObjectNotFoundError("Object does not exist") stream = io.BytesIO() blob.download_to_file(stream) stream.seek(0) @@ -67,14 +70,14 @@ def put(self, prefix: str, oid: str, data_stream: BinaryIO) -> int: def exists(self, prefix: str, oid: str) -> bool: bucket = self.storage_client.bucket(self.bucket_name) blob = bucket.blob(self._get_blob_path(prefix, oid)) - return blob.exists() # type: ignore + return blob.exists() def get_size(self, prefix: str, oid: str) -> int: bucket = self.storage_client.bucket(self.bucket_name) blob = bucket.get_blob(self._get_blob_path(prefix, oid)) if blob is None: - raise ObjectNotFound("Object does not 
exist") - return blob.size # type: ignore + raise ObjectNotFoundError("Object does not exist") + return blob.size def get_upload_action( self, @@ -126,7 +129,7 @@ def get_download_action( } def _get_blob_path(self, prefix: str, oid: str) -> str: - """Get the path to a blob in storage""" + """Get the path to a blob in storage.""" if not self.path_prefix: storage_prefix = "" elif self.path_prefix[0] == "/": @@ -166,22 +169,22 @@ def _get_signed_url( def _load_credentials( account_key_file: str | None, account_key_base64: str | None ) -> service_account.Credentials | None: - """Load Google Cloud credentials from passed configuration""" + """Load Google Cloud credentials from passed configuration.""" if account_key_file and account_key_base64: raise ValueError( - "Provide either account_key_file or account_key_base64 but not both" + "Provide either account_key_file or account_key_base64" + " but not both" ) - elif account_key_file: + if account_key_file: return service_account.Credentials.from_service_account_file( account_key_file ) - elif account_key_base64: + if account_key_base64: account_info = json.loads(base64.b64decode(account_key_base64)) return service_account.Credentials.from_service_account_info( account_info ) - else: - return None # Will use Workload Identity if available + return None # Will use Workload Identity if available def _get_workload_identity_credentials( self, expires_in: int diff --git a/giftless/storage/local_storage.py b/giftless/storage/local_storage.py index f602f85..b324a45 100644 --- a/giftless/storage/local_storage.py +++ b/giftless/storage/local_storage.py @@ -1,6 +1,9 @@ -import os +"""Local storage implementation, for development/testing or small-scale +deployments. 
+""" import shutil -from typing import Any, BinaryIO, Optional +from pathlib import Path +from typing import Any, BinaryIO from flask import Flask @@ -9,14 +12,15 @@ class LocalStorage(StreamingStorage, MultipartStorage, ViewProvider): - """Local storage implementation + """Local storage implementation. - This storage backend works by storing files in the local file system. - While it can be used in production, large scale deployment will most likely - want to use a more scalable solution such as one of the cloud storage backends. + This storage backend works by storing files in the local file + system. While it can be used in production, large scale + deployment will most likely want to use a more scalable solution + such as one of the cloud storage backends. """ - def __init__(self, path: Optional[str] = None, **_: Any) -> None: + def __init__(self, path: str | None = None, **_: Any) -> None: if path is None: path = "lfs-storage" self.path = path @@ -24,31 +28,33 @@ def __init__(self, path: Optional[str] = None, **_: Any) -> None: def get(self, prefix: str, oid: str) -> BinaryIO: path = self._get_path(prefix, oid) - if os.path.isfile(path): - return open(path, "br") + if path.is_file(): + return path.open("br") else: - raise exc.ObjectNotFound("Object was not found") + raise exc.ObjectNotFoundError(f"Object {path} was not found") def put(self, prefix: str, oid: str, data_stream: BinaryIO) -> int: path = self._get_path(prefix, oid) - directory = os.path.dirname(path) + directory = path.parent self._create_path(directory) - with open(path, "bw") as dest: + with path.open("bw") as dest: shutil.copyfileobj(data_stream, dest) return dest.tell() def exists(self, prefix: str, oid: str) -> bool: - return os.path.isfile(self._get_path(prefix, oid)) + path = self._get_path(prefix, oid) + return path.is_file() def get_size(self, prefix: str, oid: str) -> int: if self.exists(prefix, oid): - return os.path.getsize(self._get_path(prefix, oid)) - raise 
exc.ObjectNotFound("Object was not found") + path = self._get_path(prefix, oid) + return path.stat().st_size + raise exc.ObjectNotFoundError("Object was not found") def get_mime_type(self, prefix: str, oid: str) -> str: if self.exists(prefix, oid): return "application/octet-stream" - raise exc.ObjectNotFound("Object was not found") + raise exc.ObjectNotFoundError("Object was not found") def get_multipart_actions( self, @@ -57,7 +63,7 @@ def get_multipart_actions( size: int, part_size: int, expires_in: int, - extra: Optional[dict[str, Any]] = None, + extra: dict[str, Any] | None = None, ) -> dict[str, Any]: return {} @@ -67,17 +73,18 @@ def get_download_action( oid: str, size: int, expires_in: int, - extra: Optional[dict[str, Any]] = None, + extra: dict[str, Any] | None = None, ) -> dict[str, Any]: return {} def register_views(self, app: Flask) -> None: super().register_views(app) - def _get_path(self, prefix: str, oid: str) -> str: - return os.path.join(self.path, prefix, oid) + def _get_path(self, prefix: str, oid: str) -> Path: + return Path(self.path) / prefix / oid @staticmethod def _create_path(path: str) -> None: - if not os.path.isdir(path): - os.makedirs(path) + path = Path(path) + if not path.is_dir(): + path.mkdir(parents=True) diff --git a/giftless/transfer/__init__.py b/giftless/transfer/__init__.py index 0bfe0e0..95726df 100644 --- a/giftless/transfer/__init__.py +++ b/giftless/transfer/__init__.py @@ -1,12 +1,13 @@ -"""Transfer adapters +"""Transfer adapters. -See https://github.com/git-lfs/git-lfs/blob/master/docs/api/basic-transfers.md +See +https://github.com/git-lfs/git-lfs/blob/master/docs/api/basic-transfers.md for more information about what transfer APIs do in Git LFS. 
""" from abc import ABC, abstractmethod from collections.abc import Callable from functools import partial -from typing import Any, Optional, cast +from typing import Any, cast from flask import Flask @@ -21,9 +22,13 @@ _registered_adapters: dict[str, "TransferAdapter"] = {} -class TransferAdapter(ABC): - """A transfer adapter tells Git LFS Server how to respond to batch API requests""" +class TransferAdapter(ABC): # noqa:B024 + """A transfer adapter tells Git LFS Server how to respond to batch + API requests. + """ + # We don't want these to be abstract methods because the test suite + # actually instantiates a TransferAdapter, even though it's an ABC. def upload( self, organization: str, @@ -51,14 +56,18 @@ def download( def get_action( self, name: str, organization: str, repo: str ) -> Callable[[str, int], dict]: - """Shortcut for quickly getting an action callable for transfer adapter objects""" + """Shortcut for quickly getting an action callable for + transfer adapter objects. + """ return partial( getattr(self, name), organization=organization, repo=repo ) class PreAuthorizingTransferAdapter(TransferAdapter, ABC): - """A transfer adapter that can pre-authorize one or more of the actions it supports""" + """A transfer adapter that can pre-authorize one or more of the + actions it supports. 
+ """ @abstractmethod def __init__(self) -> None: @@ -144,13 +153,14 @@ def init_flask_app(app: Flask) -> None: def register_adapter(key: str, adapter: TransferAdapter) -> None: - """Register a transfer adapter""" + """Register a transfer adapter.""" _registered_adapters[key] = adapter def match_transfer_adapter( transfers: list[str], ) -> tuple[str, TransferAdapter]: + """Select a transfer adapter by key.""" for t in transfers: if t in _registered_adapters: return t, _registered_adapters[t] @@ -158,7 +168,7 @@ def match_transfer_adapter( def _init_adapter(config: dict) -> TransferAdapter: - """Call adapter factory to create a transfer adapter instance""" + """Call adapter factory to create a transfer adapter instance.""" factory: Callable[..., TransferAdapter] = get_callable(config["factory"]) adapter: TransferAdapter = factory(**config.get("options", {})) if isinstance(adapter, PreAuthorizingTransferAdapter): diff --git a/giftless/transfer/basic_external.py b/giftless/transfer/basic_external.py index cdeeaba..9e84c35 100644 --- a/giftless/transfer/basic_external.py +++ b/giftless/transfer/basic_external.py @@ -1,4 +1,4 @@ -"""External Backend Transfer Adapter +"""External Backend Transfer Adapter. This transfer adapter offers 'basic' transfers by directing clients to upload and download objects from an external storage service, such as AWS S3 or Azure @@ -12,7 +12,7 @@ """ import posixpath -from typing import Any, Optional +from typing import Any from flask import Flask @@ -26,7 +26,15 @@ class BasicExternalBackendTransferAdapter( PreAuthorizingTransferAdapter, ViewProvider ): - def __init__(self, storage: ExternalStorage, default_action_lifetime: int): + """Provides External Transfer Adapter. + + TODO @athornton: inherently PreAuthorizing feels weird. Investigate + whether there's refactoring/mixin work we can do here. 
+ """ + + def __init__( + self, storage: ExternalStorage, default_action_lifetime: int + ) -> None: super().__init__() self.storage = storage self.action_lifetime = default_action_lifetime @@ -37,7 +45,7 @@ def upload( repo: str, oid: str, size: int, - extra: Optional[dict[str, Any]] = None, + extra: dict[str, Any] | None = None, ) -> dict: prefix = posixpath.join(organization, repo) response = {"oid": oid, "size": size} @@ -51,7 +59,7 @@ def upload( prefix, oid, size, self.action_lifetime, extra ) ) - if response.get("actions", {}).get("upload"): # type: ignore + if response.get("actions", {}).get("upload"): response["authenticated"] = True headers = self._preauth_headers( organization, @@ -60,7 +68,7 @@ def upload( oid=oid, lifetime=self.VERIFY_LIFETIME, ) - response["actions"]["verify"] = { # type: ignore + response["actions"]["verify"] = { "href": VerifyView.get_verify_url(organization, repo), "header": headers, "expires_in": self.VERIFY_LIFETIME, @@ -74,7 +82,7 @@ def download( repo: str, oid: str, size: int, - extra: Optional[dict[str, Any]] = None, + extra: dict[str, Any] | None = None, ) -> dict: prefix = posixpath.join(organization, repo) response = {"oid": oid, "size": size} @@ -89,7 +97,7 @@ def download( except exc.StorageError as e: response["error"] = e.as_dict() - if response.get("actions", {}).get("download"): # type: ignore + if response.get("actions", {}).get("download"): response["authenticated"] = True return response @@ -98,19 +106,19 @@ def register_views(self, app: Flask) -> None: VerifyView.register(app, init_argument=self.storage) def _check_object(self, prefix: str, oid: str, size: int) -> None: - """Raise specific domain error if object is not valid + """Raise specific domain error if object is not valid. 
NOTE: this does not use storage.verify_object directly because - we want ObjectNotFound errors to be propagated if raised + we want ObjectNotFoundError errors to be propagated if raised """ if self.storage.get_size(prefix, oid) != size: - raise exc.InvalidObject("Object size does not match") + raise exc.InvalidObjectError("Object size does not match") def factory( storage_class: Any, storage_options: Any, action_lifetime: int ) -> BasicExternalBackendTransferAdapter: - """Factory for basic transfer adapter with external storage""" + """Build a basic transfer adapter with external storage.""" storage = get_callable(storage_class, __name__) return BasicExternalBackendTransferAdapter( storage(**storage_options), action_lifetime diff --git a/giftless/transfer/basic_streaming.py b/giftless/transfer/basic_streaming.py index 82ce3bf..d20f09d 100644 --- a/giftless/transfer/basic_streaming.py +++ b/giftless/transfer/basic_streaming.py @@ -1,13 +1,13 @@ -"""Basic Streaming Transfer Adapter +"""Basic Streaming Transfer Adapter. -This transfer adapter offers 'basic' transfers by streaming uploads / downloads -through the Git LFS HTTP server. It can use different storage backends (local, -cloud, ...). This module defines an -interface through which additional streaming backends can be implemented. +This transfer adapter offers 'basic' transfers by streaming uploads / +downloads through the Git LFS HTTP server. It can use different +storage backends (local, cloud, ...). This module defines an interface +through which additional streaming backends can be implemented. """ import posixpath -from typing import Any, BinaryIO, Optional, cast +from typing import Any, BinaryIO, cast import marshmallow from flask import Flask, Response, request, url_for @@ -24,15 +24,18 @@ class VerifyView(BaseView): - """Verify an object + """Verify an object. This view is actually not basic_streaming specific, and is used by other transfer adapters that need a 'verify' action as well. 
+ + TODO @athornton: then how about we make it a mixin, which will + make the test structures a little less weird? """ route_base = "//objects/storage" - def __init__(self, storage: VerifiableStorage): + def __init__(self, storage: VerifiableStorage) -> None: self.storage = storage @route("/verify", methods=["POST"]) @@ -55,9 +58,9 @@ def verify(self, organization: str, repo: str) -> Response: @classmethod def get_verify_url( - cls, organization: str, repo: str, oid: Optional[str] = None + cls, organization: str, repo: str, oid: str | None = None ) -> str: - """Get the URL for upload / download requests for this object""" + """Get the URL for upload / download requests for this object.""" op_name = f"{cls.__name__}:verify" url: str = url_for( op_name, @@ -70,18 +73,23 @@ def get_verify_url( class ObjectsView(BaseView): + """Provides methods for object storage management.""" + route_base = "//objects/storage" - def __init__(self, storage: StreamingStorage): + def __init__(self, storage: StreamingStorage) -> None: self.storage = storage def put(self, organization: str, repo: str, oid: str) -> Response: - """Upload a file to local storage - - For now, I am not sure this actually streams chunked uploads without reading the entire - content into memory. It seems that in order to support this, we will need to dive deeper - into the WSGI Server -> Werkzeug -> Flask stack, and it may also depend on specific WSGI - server implementation and even how a proxy (e.g. nginx) is configured. + """Upload a file to local storage. + + TODO @rufuspollock: For now, I am not sure this actually + streams chunked uploads without reading the entire content + into memory. It seems that in order to support this, we will + need to dive deeper into the WSGI Server -> Werkzeug -> Flask + stack, and it may also depend on specific WSGI server + implementation and even how a proxy (e.g. nginx) is + configured. 
""" self._check_authorization( organization, repo, Permission.WRITE, oid=oid @@ -95,7 +103,7 @@ def put(self, organization: str, repo: str, oid: str) -> Response: return Response(status=200) def get(self, organization: str, repo: str, oid: str) -> Response: - """Get an file open file stream from local storage""" + """Get an open file stream from local storage.""" self._check_authorization(organization, repo, Permission.READ, oid=oid) path = posixpath.join(organization, repo) @@ -127,9 +135,9 @@ def get_storage_url( operation: str, organization: str, repo: str, - oid: Optional[str] = None, + oid: str | None = None, ) -> str: - """Get the URL for upload / download requests for this object""" + """Get the URL for upload / download requests for this object.""" op_name = f"{cls.__name__}:{operation}" url: str = url_for( op_name, @@ -144,7 +152,11 @@ def get_storage_url( class BasicStreamingTransferAdapter( PreAuthorizingTransferAdapter, ViewProvider ): - def __init__(self, storage: StreamingStorage, action_lifetime: int): + """Provides Streaming Transfers.""" + + def __init__( + self, storage: StreamingStorage, action_lifetime: int + ) -> None: super().__init__() self.storage = storage self.action_lifetime = action_lifetime @@ -155,7 +167,7 @@ def upload( repo: str, oid: str, size: int, - extra: Optional[dict[str, Any]] = None, + extra: dict[str, Any] | None = None, ) -> dict: response = {"oid": oid, "size": size} @@ -196,7 +208,7 @@ def download( repo: str, oid: str, size: int, - extra: Optional[dict[str, Any]] = None, + extra: dict[str, Any] | None, ) -> dict: response = {"oid": oid, "size": size} @@ -244,7 +256,7 @@ def register_views(self, app: Flask) -> None: def factory( storage_class: Any, storage_options: Any, action_lifetime: int ) -> BasicStreamingTransferAdapter: - """Factory for basic transfer adapter with local storage""" + """Build a basic transfer adapter with local storage.""" storage = get_callable(storage_class, __name__) return 
BasicStreamingTransferAdapter( storage(**storage_options), action_lifetime diff --git a/giftless/transfer/multipart.py b/giftless/transfer/multipart.py index a81e983..2fd18a9 100644 --- a/giftless/transfer/multipart.py +++ b/giftless/transfer/multipart.py @@ -1,8 +1,7 @@ -"""Multipart Transfer Adapter -""" +"""Multipart Transfer Adapter.""" import posixpath -from typing import Any, Optional +from typing import Any from flask import Flask @@ -12,17 +11,19 @@ from giftless.util import get_callable from giftless.view import ViewProvider -DEFAULT_PART_SIZE = 10240000 # 10mb -DEFAULT_ACTION_LIFETIME = 6 * 3600 # 6 hours +DEFAULT_PART_SIZE = 10240000 # 10MB (-ish) +DEFAULT_ACTION_LIFETIME = 6 * 60 * 60 # 6 hours class MultipartTransferAdapter(PreAuthorizingTransferAdapter, ViewProvider): + """Transfer Adapter supporting multipart methods.""" + def __init__( self, storage: MultipartStorage, default_action_lifetime: int, max_part_size: int = DEFAULT_PART_SIZE, - ): + ) -> None: super().__init__() self.storage = storage self.max_part_size = max_part_size @@ -34,7 +35,7 @@ def upload( repo: str, oid: str, size: int, - extra: Optional[dict[str, Any]] = None, + extra: dict[str, Any] | None = None, ) -> dict: prefix = posixpath.join(organization, repo) response = {"oid": oid, "size": size} @@ -56,7 +57,7 @@ def upload( oid=oid, lifetime=self.VERIFY_LIFETIME, ) - response["actions"]["verify"] = { # type: ignore + response["actions"]["verify"] = { "href": VerifyView.get_verify_url(organization, repo), "header": headers, "expires_in": self.VERIFY_LIFETIME, @@ -70,7 +71,7 @@ def download( repo: str, oid: str, size: int, - extra: Optional[dict[str, Any]] = None, + extra: dict[str, Any] | None = None, ) -> dict: prefix = posixpath.join(organization, repo) response = {"oid": oid, "size": size} @@ -85,25 +86,28 @@ def download( except exc.StorageError as e: response["error"] = e.as_dict() - if response.get("actions", {}).get("download"): # type: ignore + if response.get("actions", 
{}).get("download"): response["authenticated"] = True return response def register_views(self, app: Flask) -> None: - # FIXME: this is broken. Need to find a smarter way for multiple transfer adapters to provide the same view - # VerifyView.register(app, init_argument=self.storage) + # TODO @rufuspollock: this is broken. Need to find a smarter + # way for multiple transfer adapters to provide the same view + # -- broken: VerifyView.register(app, init_argument=self.storage) + # TODO @athornton: does this maybe indicate a classvar shadowing or + # updating issue? Investigate that. if isinstance(self.storage, ViewProvider): self.storage.register_views(app) def _check_object(self, prefix: str, oid: str, size: int) -> None: - """Raise specific domain error if object is not valid + """Raise specific domain error if object is not valid. NOTE: this does not use storage.verify_object directly because - we want ObjectNotFound errors to be propagated if raised + we want ObjectNotFoundError errors to be propagated if raised. 
""" if self.storage.get_size(prefix, oid) != size: - raise exc.InvalidObject("Object size does not match") + raise exc.InvalidObjectError("Object size does not match") def factory( @@ -112,11 +116,13 @@ def factory( action_lifetime: int = DEFAULT_ACTION_LIFETIME, max_part_size: int = DEFAULT_PART_SIZE, ) -> MultipartTransferAdapter: - """Factory for multipart transfer adapter with storage""" + """Build a multipart transfer adapter with storage.""" try: storage = get_callable(storage_class, __name__) except (AttributeError, ImportError): - raise ValueError(f"Unable to load storage module: {storage_class}") + raise ValueError( + f"Unable to load storage module: {storage_class}" + ) from None return MultipartTransferAdapter( storage(**storage_options), action_lifetime, diff --git a/giftless/transfer/types.py b/giftless/transfer/types.py index 151f7be..8dc8bb4 100644 --- a/giftless/transfer/types.py +++ b/giftless/transfer/types.py @@ -1,31 +1,30 @@ -"""Some useful type definitions for transfer protocols -""" -import sys -from typing import Any - -if sys.version_info >= (3, 8): - from typing import TypedDict -else: - from typing_extensions import TypedDict +"""Some useful type definitions for transfer protocols.""" +from typing import Any, TypedDict class ObjectAttributes(TypedDict): - """Type for object attributes sent in batch request""" + """Type for object attributes sent in batch request.""" oid: str size: int class BasicUploadActions(TypedDict, total=False): + """Fundamental actions for upload.""" + upload: dict[str, Any] verify: dict[str, Any] class UploadObjectAttributes(ObjectAttributes, total=False): + """Convert BasicUploadActions to object attributes.""" + actions: BasicUploadActions class MultipartUploadActions(TypedDict, total=False): + """Additional actions to support multipart uploads.""" + init: dict[str, Any] commit: dict[str, Any] parts: list[dict[str, Any]] @@ -34,4 +33,6 @@ class MultipartUploadActions(TypedDict, total=False): class 
MultipartUploadObjectAttributes(ObjectAttributes, total=False): + """Convert MultipartUploadActions to object attributes.""" + actions: MultipartUploadActions diff --git a/giftless/util.py b/giftless/util.py index bc50cbb..0ca7c37 100644 --- a/giftless/util.py +++ b/giftless/util.py @@ -1,16 +1,15 @@ -"""Miscellanea -""" +"""Miscellanea.""" import importlib from collections.abc import Callable, Iterable -from typing import Any, Optional +from typing import Any from urllib.parse import urlencode def get_callable( - callable_str: str, base_package: Optional[str] = None + callable_str: str, base_package: str | None = None ) -> Callable: """Get a callable function / class constructor from a string of the form - `package.subpackage.module:callable` + `package.subpackage.module:callable`. >>> type(get_callable('os.path:basename')).__name__ 'function' @@ -29,11 +28,11 @@ def get_callable( "Expecting base_package to be set if only class name is provided" ) - return getattr(module, callable_name) # type: ignore + return getattr(module, callable_name) def to_iterable(val: Any) -> Iterable: - """Get something we can iterate over from an unknown type + """Get something we can iterate over from an unknown type. >>> i = to_iterable([1, 2, 3]) >>> next(iter(i)) @@ -55,7 +54,7 @@ def to_iterable(val: Any) -> Iterable: >>> next(iter(i)) 1 """ - if isinstance(val, Iterable) and not isinstance(val, (str, bytes)): + if isinstance(val, Iterable) and not isinstance(val, str | bytes): return val return (val,) @@ -64,10 +63,12 @@ def add_query_params(url: str, params: dict[str, Any]) -> str: """Safely add query params to a url that may or may not already contain query params. 
-    >>> add_query_params('https://example.org', {'param1': 'value1', 'param2': 'value2'})
+    >>> add_query_params(
+    ...     'https://example.org', {'param1': 'value1', 'param2': 'value2'})
     'https://example.org?param1=value1&param2=value2'
 
-    >>> add_query_params('https://example.org?param1=value1', {'param2': 'value2'})
+    >>> add_query_params(
+    ...     'https://example.org?param1=value1', {'param2': 'value2'})
     'https://example.org?param1=value1&param2=value2'
     """  # noqa: E501
     urlencoded_params = urlencode(params)
@@ -76,7 +77,7 @@ def add_query_params(url: str, params: dict[str, Any]) -> str:
 
 
 def safe_filename(original_filename: str) -> str:
-    """Returns a filename safe to use in HTTP headers, formed from the
+    """Return a filename safe to use in HTTP headers, formed from the
     given original filename.
 
     >>> safe_filename("example(1).txt")
diff --git a/giftless/view.py b/giftless/view.py
index 70989a2..8472c46 100644
--- a/giftless/view.py
+++ b/giftless/view.py
@@ -1,8 +1,7 @@
-"""Flask-Classful View Classes
-"""
-from typing import Any
+"""Flask-Classful View Classes."""
+from typing import Any, ClassVar
 
-from flask import Flask, Response
+from flask import Flask
 from flask_classful import FlaskView
 from webargs.flaskparser import parser
 
@@ -12,13 +11,13 @@
 
 
 class BaseView(FlaskView):
-    """This extends on Flask-Classful's base view class to add some common custom
-    functionality
+    """Extends Flask-Classful's base view class to add some common
+    custom functionality.
+    
""" - decorators = [authn.login_required] + decorators: ClassVar = [authn.login_required] - representations = { + representations: ClassVar = { "application/json": representation.output_json, representation.GIT_LFS_MIME_TYPE: representation.output_git_lfs_json, "flask-classful/default": representation.output_git_lfs_json, @@ -40,10 +39,12 @@ def _check_authorization( permission: Permission, oid: str | None = None, ) -> None: - """Check the current user is authorized to perform an action and raise an exception otherwise""" + """Check the current user is authorized to perform an action + and raise an exception otherwise. + """ if not cls._is_authorized(organization, repo, permission, oid): raise exc.Forbidden( - "Your are not authorized to perform this action" + "You are not authorized to perform this action" ) @staticmethod @@ -53,7 +54,7 @@ def _is_authorized( permission: Permission, oid: str | None = None, ) -> bool: - """Check the current user is authorized to perform an action""" + """Check the current user is authorized to perform an action.""" identity = authn.get_identity() return identity is not None and identity.is_authorized( organization, repo, permission, oid @@ -61,12 +62,12 @@ def _is_authorized( class BatchView(BaseView): - """Batch operations""" + """Batch operations.""" route_base = "//objects/batch" def post(self, organization: str, repo: str) -> dict[str, Any]: - """Batch operations""" + """Batch operations.""" payload = parser.parse(schema.batch_request_schema) try: @@ -74,7 +75,7 @@ def post(self, organization: str, repo: str) -> dict[str, Any]: payload["transfers"] ) except ValueError as e: - raise exc.InvalidPayload(str(e)) + raise exc.InvalidPayload(str(e)) from None permission = ( Permission.WRITE @@ -84,7 +85,8 @@ def post(self, organization: str, repo: str) -> dict[str, Any]: try: self._check_authorization(organization, repo, permission) except exc.Forbidden: - # User doesn't have global permission to the entire namespace, but may be 
authorized for all objects + # User doesn't have global permission to the entire namespace, + # but may be authorized for all objects if not all( self._is_authorized(organization, repo, permission, o["oid"]) for o in payload["objects"] @@ -105,8 +107,9 @@ def post(self, organization: str, repo: str) -> dict[str, Any]: "Cannot validate any of the requested objects" ) - # TODO: Check Accept header - # TODO: do we need an output schema? + # TODO @rufuspollock: Check Accept header + # TODO @athornton: do we need an output schema? If so...should + # we just turn this into a Pydantic app? return response @@ -119,10 +122,12 @@ def _is_error(obj: dict[str, Any], code: int | None = None) -> bool: class ViewProvider: - """ViewProvider is a marker interface for storage and transfer adapters that can provide their own Flask views + """ViewProvider is a marker interface for storage and transfer + adapters that can provide their own Flask views. - This allows transfer and storage backends to register routes for accessing or verifying files, for example, - directly from the Giftless HTTP server. + This allows transfer and storage backends to register routes for + accessing or verifying files, for example, directly from the + Giftless HTTP server. """ def register_views(self, app: Flask) -> None: diff --git a/giftless/wsgi_entrypoint.py b/giftless/wsgi_entrypoint.py index a1604cf..84da354 100644 --- a/giftless/wsgi_entrypoint.py +++ b/giftless/wsgi_entrypoint.py @@ -1,6 +1,6 @@ -"""Entry point module for WSGI +"""Entry point module for WSGI. -This is used when running the app using a WSGI server such as uWSGI +This is used when running the app using a WSGI server such as uWSGI. 
""" from .app import init_app diff --git a/pyproject.toml b/pyproject.toml index 8dd4e10..04cdf24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,21 +53,6 @@ build-backend = "setuptools.build_meta" [tool.setuptools_scm] -[tool.black] -line-length = 79 -target-version = ["py310"] -exclude = ''' -/( - \.eggs - | \.git - | \.mypy_cache - | \.tox - | \.venv - | _build - | build - | dist -)/ -''' # Use single-quoted strings so TOML treats the string like a Python r-string # Multi-line strings are implicitly treated by black as regular expressions @@ -165,7 +150,7 @@ python_files = [ # Reference for rules: https://docs.astral.sh/ruff/rules/ [tool.ruff] exclude = [ - "docs/conf.py", + "docs/source/conf.py", ] line-length = 79 ignore = [ @@ -185,6 +170,8 @@ ignore = [ "D205", # our documentation style allows a folded first line "EM101", # justification (duplicate string in traceback) is silly "EM102", # justification (duplicate string in traceback) is silly + "FBT001", # positional booleans are normal for Pydantic field defaults + "FBT002", # positional booleans are normal for Pydantic field defaults "FBT003", # positional booleans are normal for Pydantic field defaults "FIX002", # point of a TODO comment is that we're not ready to fix it "G004", # forbidding logging f-strings is appealing, but not our style diff --git a/tests/auth/test_auth.py b/tests/auth/test_auth.py index 2f8501e..1b81b2a 100644 --- a/tests/auth/test_auth.py +++ b/tests/auth/test_auth.py @@ -1,5 +1,4 @@ -"""Unit tests for auth module -""" +"""Unit tests for auth module.""" from typing import Any import pytest @@ -9,7 +8,7 @@ def test_default_identity_properties() -> None: - """Test the basic properties of the default identity object""" + """Test the basic properties of the default identity object.""" user = DefaultIdentity( "arthur", "kingofthebritons", "arthur@camelot.gov.uk" ) @@ -60,7 +59,7 @@ def test_default_identity_denied_by_default(requested: dict[str, Any]) -> None: 
@pytest.mark.parametrize( - "requested, expected", + ("requested", "expected"), [ ( { @@ -78,14 +77,6 @@ def test_default_identity_denied_by_default(requested: dict[str, Any]) -> None: }, False, ), - ( - { - "permission": Permission.READ, - "organization": "myorg", - "repo": "somerepo", - }, - True, - ), ( { "permission": Permission.READ, @@ -118,7 +109,7 @@ def test_default_identity_allow_specific_repo( @pytest.mark.parametrize( - "requested, expected", + ("requested", "expected"), [ ( { @@ -184,7 +175,7 @@ def test_default_identity_allow_specific_org_permissions( @pytest.mark.parametrize( - "requested, expected", + ("requested", "expected"), [ ( { @@ -231,13 +222,13 @@ def test_default_identity_allow_specific_org_permissions( def test_allow_anon_read_only( requested: dict[str, Any], expected: bool ) -> None: - """Test that an anon user with read only permissions works as expected""" + """Test that an anon user with read only permissions works as expected.""" user = allow_anon.read_only(None) assert expected is user.is_authorized(**requested) @pytest.mark.parametrize( - "requested, expected", + ("requested", "expected"), [ ( { @@ -284,13 +275,13 @@ def test_allow_anon_read_only( def test_allow_anon_read_write( requested: dict[str, Any], expected: bool ) -> None: - """Test that an anon user with read only permissions works as expected""" + """Test that an anon user with read only permissions works as expected.""" user = allow_anon.read_write(None) assert expected is user.is_authorized(**requested) def test_anon_user_interface() -> None: - """Test that an anonymous user has the right interface""" + """Test that an anonymous user has the right interface.""" user = allow_anon.read_only(None) assert isinstance(user, allow_anon.AnonymousUser) assert user.name == "anonymous" diff --git a/tests/auth/test_jwt.py b/tests/auth/test_jwt.py index 773f1c7..754d8e6 100644 --- a/tests/auth/test_jwt.py +++ b/tests/auth/test_jwt.py @@ -1,12 +1,12 @@ +"""Tests for JWT 
authorization.""" import base64 -import os -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone +from pathlib import Path from typing import Any import flask import jwt import pytest -import pytz from giftless.auth import Unauthorized from giftless.auth.identity import DefaultIdentity, Permission @@ -16,16 +16,12 @@ JWT_HS_KEY = b"some-random-secret" # Asymmetric key files used in tests -JWT_RS_PRI_KEY = os.path.join( - os.path.dirname(__file__), "data", "test-key.pem" -) -JWT_RS_PUB_KEY = os.path.join( - os.path.dirname(__file__), "data", "test-key.pub.pem" -) +JWT_RS_PRI_KEY = Path(__file__).parent / "data" / "test-key.pem" +JWT_RS_PUB_KEY = Path(__file__).parent / "data" / "test-key.pub.pem" def test_jwt_can_authorize_request_symmetric_key(app: flask.Flask) -> None: - """Test basic JWT authorizer functionality""" + """Test basic JWT authorizer functionality: HS256 symmetric.""" authz = JWTAuthenticator(private_key=JWT_HS_KEY, algorithm="HS256") token = _get_test_token() with app.test_request_context( @@ -39,7 +35,7 @@ def test_jwt_can_authorize_request_symmetric_key(app: flask.Flask) -> None: def test_jwt_can_authorize_request_asymmetric_key(app: flask.Flask) -> None: - """Test basic JWT authorizer functionality""" + """Test basic JWT authorizer functionality: RS256 asymmetric.""" authz = factory(public_key_file=JWT_RS_PUB_KEY, algorithm="RS256") token = _get_test_token(algo="RS256") with app.test_request_context( @@ -53,7 +49,7 @@ def test_jwt_can_authorize_request_asymmetric_key(app: flask.Flask) -> None: def test_jwt_can_authorize_request_token_in_qs(app: flask.Flask) -> None: - """Test basic JWT authorizer functionality""" + """Test basic JWT authorizer functionality with query param.""" authz = JWTAuthenticator(private_key=JWT_HS_KEY, algorithm="HS256") token = _get_test_token() with app.test_request_context( @@ -67,7 +63,7 @@ def test_jwt_can_authorize_request_token_in_qs(app: flask.Flask) -> None: def 
test_jwt_can_authorize_request_token_as_basic_password( app: flask.Flask, ) -> None: - """Test that we can pass a JWT token as 'Basic' authorization password""" + """Test that we can pass a JWT token as 'Basic' authorization password.""" authz = JWTAuthenticator(private_key=JWT_HS_KEY, algorithm="HS256") token = _get_test_token() auth_value = base64.b64encode( @@ -87,7 +83,9 @@ def test_jwt_can_authorize_request_token_as_basic_password( def test_jwt_can_authorize_request_token_basic_password_disabled( app: flask.Flask, ) -> None: - """Test that we can pass a JWT token as 'Basic' authorization password""" + """Test that we can pass a JWT token as 'Basic' authorization password + when user is None. + """ authz = JWTAuthenticator( private_key=JWT_HS_KEY, algorithm="HS256", basic_auth_user=None ) @@ -106,7 +104,9 @@ def test_jwt_can_authorize_request_token_basic_password_disabled( def test_jwt_with_wrong_kid_doesnt_authorize_request(app: flask.Flask) -> None: - """JWT authorizer only considers a JWT token if it has the right key ID in the header""" + """JWT authorizer only considers a JWT token if it has the right key ID + in the header. 
+ """ authz = JWTAuthenticator( private_key=JWT_HS_KEY, algorithm="HS256", key_id="must-be-this-key" ) @@ -121,7 +121,7 @@ def test_jwt_with_wrong_kid_doesnt_authorize_request(app: flask.Flask) -> None: def test_jwt_expired_throws_401(app: flask.Flask) -> None: - """If we get a JWT token who's expired, we should raise a 401 error""" + """If we get a JWT token that has expired, we should raise a 401 error.""" authz = JWTAuthenticator(private_key=JWT_HS_KEY, algorithm="HS256") token = _get_test_token(lifetime=-600) # expired 10 minutes ago with app.test_request_context( @@ -154,7 +154,10 @@ def test_jwt_pre_authorize_action() -> None: # Check that now() - expiration time is within 5 seconds of 120 seconds assert ( abs( - (datetime.fromtimestamp(payload["exp"]) - datetime.now()).seconds + ( + datetime.fromtimestamp(payload["exp"], tz=timezone.utc) + - datetime.now(tz=timezone.utc) + ).seconds - 120 ) < 5 @@ -182,7 +185,10 @@ def test_jwt_pre_authorize_action_custom_lifetime() -> None: # Check that now() - expiration time is within 5 seconds of 3600 seconds assert ( abs( - (datetime.fromtimestamp(payload["exp"]) - datetime.now()).seconds + ( + datetime.fromtimestamp(payload["exp"], tz=timezone.utc) + - datetime.now(tz=timezone.utc) + ).seconds - 3600 ) < 5 @@ -190,7 +196,7 @@ def test_jwt_pre_authorize_action_custom_lifetime() -> None: @pytest.mark.parametrize( - "scopes, auth_check, expected", + ("scopes", "auth_check", "expected"), [ ( [], @@ -237,15 +243,6 @@ def test_jwt_pre_authorize_action_custom_lifetime() -> None: }, False, ), - ( - ["obj:myorg/myrepo/*"], - { - "organization": "myorg", - "repo": "myrepo", - "permission": Permission.READ, - }, - True, - ), ( ["obj:myorg/myrepo/*:read"], { @@ -407,7 +404,7 @@ def test_jwt_pre_authorize_action_custom_lifetime() -> None: def test_jwt_scopes_authorize_actions( app: flask.Flask, scopes: str, auth_check: dict[str, Any], expected: bool ) -> None: - """Test that JWT token scopes can control authorization""" + """Test 
that JWT token scopes can control authorization.""" authz = JWTAuthenticator(private_key=JWT_HS_KEY, algorithm="HS256") token = _get_test_token(scopes=scopes) with app.test_request_context( @@ -422,7 +419,9 @@ def test_jwt_scopes_authorize_actions( def test_jwt_scopes_authorize_actions_with_anon_user(app: flask.Flask) -> None: - """Test that authorization works even if we don't have any user ID / email / name""" + """Test that authorization works even if we don't have any user ID + / email / name. + """ scopes = ["obj:myorg/myrepo/*"] authz = JWTAuthenticator(private_key=JWT_HS_KEY, algorithm="HS256") token = _get_test_token(scopes=scopes, sub=None, name=None, email=None) @@ -440,7 +439,7 @@ def test_jwt_scopes_authorize_actions_with_anon_user(app: flask.Flask) -> None: @pytest.mark.parametrize( - "scope_str, expected", + ("scope_str", "expected"), [ ( "org:myorg:*", @@ -535,14 +534,14 @@ def test_jwt_scopes_authorize_actions_with_anon_user(app: flask.Flask) -> None: ], ) def test_scope_parsing(scope_str: str, expected: dict[str, Any]) -> None: - """Test scope string parsing works as expected""" + """Test scope string parsing works as expected.""" scope = Scope.from_string(scope_str) for k, v in expected.items(): assert getattr(scope, k) == v @pytest.mark.parametrize( - "scope, expected", + ("scope", "expected"), [ (Scope("org", "myorg"), "org:myorg"), (Scope("org", "myorg", subscope="meta"), "org:myorg:meta:*"), @@ -557,7 +556,7 @@ def test_scope_parsing(scope_str: str, expected: dict[str, Any]) -> None: ], ) def test_scope_stringify(scope: Scope, expected: str) -> None: - """Test scope stringification works as expected""" + """Test scope stringification works as expected.""" assert expected == str(scope) @@ -568,7 +567,7 @@ def _get_test_token( **kwargs: Any, ) -> str: payload = { - "exp": datetime.now(tz=pytz.utc) + timedelta(seconds=lifetime), + "exp": datetime.now(tz=timezone.utc) + timedelta(seconds=lifetime), "sub": "some-user-id", } @@ -577,7 +576,7 @@ def 
_get_test_token( if algo == "HS256": key = JWT_HS_KEY elif algo == "RS256": - with open(JWT_RS_PRI_KEY, "rb") as f: + with JWT_RS_PRI_KEY.open("rb") as f: key = f.read() else: raise ValueError(f"Don't know how to test algo: {algo}") diff --git a/tests/conftest.py b/tests/conftest.py index 7792037..fdfe08c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ +"""Fixtures for giftless testing.""" import pathlib import shutil -from typing import Generator, cast +from collections.abc import Generator import flask import pytest @@ -16,14 +17,14 @@ def storage_path(tmp_path: pathlib.Path) -> Generator: path = tmp_path / "lfs-tests" path.mkdir() try: - yield str(tmp_path) + yield str(path) finally: shutil.rmtree(path) @pytest.fixture def app(storage_path: str) -> flask.Flask: - """Session fixture to configure the Flask app""" + """Session fixture to configure the Flask app.""" app = init_app( additional_config={ "TESTING": True, @@ -55,13 +56,12 @@ def test_client(app_context: AppContext) -> FlaskClient: @pytest.fixture -def authz_full_access( +def _authz_full_access( app_context: AppContext, -) -> ( - Generator -): # needed to ensure we call init_authenticators before app context is destroyed - """Fixture that enables full anonymous access to all actions for tests that - use it +) -> Generator: + """Fixture that enables full anonymous access to all actions for + tests that use it. Try block needed to ensure we call + init_authenticators before app context is destroyed. 
""" try: authentication.push_authenticator( diff --git a/tests/helpers.py b/tests/helpers.py index 80adc4a..9352d00 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,13 +1,13 @@ -"""Test helpers -""" -import os +"""Test helpers.""" +from collections.abc import Sequence +from pathlib import Path from typing import Any def batch_request_payload( - delete_keys: list[str] = [], **kwargs: Any + delete_keys: Sequence[str] = (), **kwargs: Any ) -> dict[str, Any]: - """Generate sample batch request payload""" + """Generate sample batch request payload.""" payload = { "operation": "download", "transfers": ["basic"], @@ -25,15 +25,17 @@ def batch_request_payload( def create_file_in_storage( storage_path: str, org: str, repo: str, filename: str, size: int = 1 ) -> None: - """Put a dummy file in the storage path for a specific org / repo / oid combination + """Put a dummy file in the storage path for a specific org / repo + / oid combination. - This is useful where we want to test download / verify actions without relying on - 'put' actions to work + This is useful where we want to test download / verify actions + without relying on 'put' actions to work. - This assumes cleanup is done somewhere else (e.g. in the 'storage_path' fixture) + This assumes cleanup is done somewhere else (e.g. in the + 'storage_path' fixture). 
""" - repo_path = os.path.join(storage_path, org, repo) - os.makedirs(repo_path, exist_ok=True) - with open(os.path.join(repo_path, filename), "wb") as f: + repo_path = Path(storage_path) / org / repo + repo_path.mkdir(parents=True, exist_ok=True) + with Path(repo_path / filename).open("wb") as f: for c in (b"0" for _ in range(size)): f.write(c) diff --git a/tests/storage/__init__.py b/tests/storage/__init__.py index 2c1f8bf..08c3b12 100644 --- a/tests/storage/__init__.py +++ b/tests/storage/__init__.py @@ -4,7 +4,7 @@ import pytest from giftless.storage import ExternalStorage, StreamingStorage -from giftless.storage.exc import ObjectNotFound +from giftless.storage.exc import ObjectNotFoundError ARBITRARY_OID = ( "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824" @@ -15,13 +15,16 @@ class _CommonStorageAbstractTests: - """Common tests for all storage backend types and interfaces + """Common tests for all storage backend types and interfaces. - This should not be used directly, because it is inherited by other AbstractTest test suites. + This should not be used directly, because it is inherited by other + AbstractTest test suites. + + Perhaps that means that we should make this an ABC? """ def test_get_size(self, storage_backend: StreamingStorage) -> None: - """Test getting the size of a stored object""" + """Test getting the size of a stored object.""" content = b"The contents of a file-like object" storage_backend.put("org/repo", ARBITRARY_OID, io.BytesIO(content)) assert len(content) == storage_backend.get_size( @@ -31,12 +34,14 @@ def test_get_size(self, storage_backend: StreamingStorage) -> None: def test_get_size_not_existing( self, storage_backend: StreamingStorage ) -> None: - """Test getting the size of a non-existing object raises an exception""" - with pytest.raises(ObjectNotFound): + """Test getting the size of a non-existing object raises an + exception. 
+ """ + with pytest.raises(ObjectNotFoundError): storage_backend.get_size("org/repo", ARBITRARY_OID) def test_exists_exists(self, storage_backend: StreamingStorage) -> None: - """Test that calling exists on an existing object returns True""" + """Test that calling exists on an existing object returns True.""" content = b"The contents of a file-like object" storage_backend.put("org/repo", ARBITRARY_OID, io.BytesIO(content)) assert storage_backend.exists("org/repo", ARBITRARY_OID) @@ -44,14 +49,18 @@ def test_exists_exists(self, storage_backend: StreamingStorage) -> None: def test_exists_not_exists( self, storage_backend: StreamingStorage ) -> None: - """Test that calling exists on a non-existing object returns False""" + """Test that calling exists on a non-existing object returns False.""" assert not storage_backend.exists("org/repo", ARBITRARY_OID) class _VerifiableStorageAbstractTests: - """Mixin class for other base storage adapter test classes that implement VerifyableStorage + """Mixin class for other base storage adapter test classes that implement + VerifiableStorage. + + This should not be used directly, because it is inherited by other + AbstractTest test suites. - This should not be used directly, because it is inherited by other AbstractTest test suites. + Perhaps that means this should be an ABC? """ def test_verify_object_ok(self, storage_backend: StreamingStorage) -> None: @@ -81,14 +90,18 @@ def test_verify_object_not_there( class StreamingStorageAbstractTests( _CommonStorageAbstractTests, _VerifiableStorageAbstractTests ): - """Mixin for testing the StreamingStorage methods of a backend that implements StreamingStorage + """Mixin for testing the StreamingStorage methods of a backend + that implements StreamingStorage. + + To use, create a concrete test class mixing this class in, and + define a fixture named ``storage_backend`` that returns an + appropriate storage backend object. 
- To use, create a concrete test class mixing this class in, and define a fixture named - ``storage_backend`` that returns an appropriate storage backend object. + Again, perhaps this should be defined as an ABC? """ def test_put_get_object(self, storage_backend: StreamingStorage) -> None: - """Test a full put-then-get cycle""" + """Test a full put-then-get cycle.""" content = b"The contents of a file-like object" written = storage_backend.put( "org/repo", ARBITRARY_OID, io.BytesIO(content) @@ -103,18 +116,23 @@ def test_put_get_object(self, storage_backend: StreamingStorage) -> None: def test_get_raises_if_not_found( self, storage_backend: StreamingStorage ) -> None: - """Test that calling get for a non-existing object raises""" - with pytest.raises(ObjectNotFound): + """Test that calling get for a non-existing object raises.""" + with pytest.raises(ObjectNotFoundError): storage_backend.get("org/repo", ARBITRARY_OID) class ExternalStorageAbstractTests( _CommonStorageAbstractTests, _VerifiableStorageAbstractTests ): - """Mixin for testing the ExternalStorage methods of a backend that implements ExternalStorage + """Mixin for testing the ExternalStorage methods of a backend that + implements ExternalStorage. + + To use, create a concrete test class mixing this class in, and + define a fixture named ``storage_backend`` that returns an + appropriate storage backend object. + - To use, create a concrete test class mixing this class in, and define a fixture named - ``storage_backend`` that returns an appropriate storage backend object. + Again, perhaps this should be defined as an ABC? 
""" def test_get_upload_action( diff --git a/tests/storage/test_amazon_s3.py b/tests/storage/test_amazon_s3.py index db71749..78ed13f 100644 --- a/tests/storage/test_amazon_s3.py +++ b/tests/storage/test_amazon_s3.py @@ -1,5 +1,4 @@ -"""Tests for the Azure storage backend -""" +"""Tests for the Azure storage backend.""" import os from base64 import b64decode from binascii import unhexlify @@ -22,14 +21,15 @@ @pytest.fixture def storage_backend() -> Generator[AmazonS3Storage, None, None]: - """Provide a S3 Storage backend for all AWS S3 tests + """Provide a S3 Storage backend for all AWS S3 tests. For this to work against production S3, you need to set boto3 auth: 1. AWS_ACCESS_KEY_ID 2. AWS_SECRET_ACCESS_KEY For more details please see: - https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html#environment-variables + https://boto3.amazonaws.com/v1/documentation/api/latest/ + guide/credentials.html#environment-variables If these variables are not set, and pytest-vcr is not in use, the tests *will* fail. @@ -47,7 +47,9 @@ def storage_backend() -> Generator[AmazonS3Storage, None, None]: try: bucket.objects.all().delete() except Exception as e: - raise pytest.PytestWarning(f"Could not clean up after test: {e}") + raise pytest.PytestWarning( + f"Could not clean up after test: {e}" + ) from None @pytest.fixture(scope="module") diff --git a/tests/storage/test_azure.py b/tests/storage/test_azure.py index fcca9a8..d365c0a 100644 --- a/tests/storage/test_azure.py +++ b/tests/storage/test_azure.py @@ -1,5 +1,4 @@ -"""Tests for the Azure storage backend -""" +"""Tests for the Azure storage backend.""" import os from collections.abc import Generator from typing import Any @@ -18,12 +17,14 @@ @pytest.fixture def storage_backend() -> Generator[AzureBlobsStorage, None, None]: - """Provide an Azure Blob Storage backend for all Azure tests + """Provide an Azure Blob Storage backend for all Azure tests. 
- For this to work against production Azure, you need to set ``AZURE_CONNECTION_STRING`` - and ``AZURE_CONTAINER`` environment variables when running the tests. + For this to work against production Azure, you need to set + ``AZURE_CONNECTION_STRING`` and ``AZURE_CONTAINER`` environment + variables when running the tests. - If these variables are not set, and pytest-vcr is not in use, the tests *will* fail. + If these variables are not set, and pytest-vcr is not in use, the + tests *will* fail. """ connection_str = os.environ.get("AZURE_CONNECTION_STRING") container_name = os.environ.get("AZURE_CONTAINER") @@ -61,10 +62,7 @@ def vcr_config() -> dict[str, Any]: os.environ.get("AZURE_CONNECTION_STRING") and os.environ.get("AZURE_CONTAINER") ) - if live_tests: - mode = "once" - else: - mode = "none" + mode = "once" if live_tests else "none" return { "filter_headers": [("authorization", "fake-authz-header")], "record_mode": mode, diff --git a/tests/storage/test_google_cloud.py b/tests/storage/test_google_cloud.py index d7d8ea9..7abad44 100644 --- a/tests/storage/test_google_cloud.py +++ b/tests/storage/test_google_cloud.py @@ -1,5 +1,4 @@ -"""Tests for the Google Cloud Storage storage backend -""" +"""Tests for the Google Cloud Storage storage backend.""" import os from collections.abc import Generator from typing import Any @@ -58,7 +57,7 @@ @pytest.fixture def storage_backend() -> Generator[GoogleCloudStorage, None, None]: - """Provide a Google Cloud Storage backend for all GCS tests + """Provide a Google Cloud Storage backend for all GCS tests. 
For this to work against production Google Cloud, you need to set ``GCP_ACCOUNT_KEY_FILE``, ``GCP_PROJECT_NAME`` and ``GCP_BUCKET_NAME`` @@ -90,7 +89,7 @@ def storage_backend() -> Generator[GoogleCloudStorage, None, None]: except GoogleAPIError as e: raise pytest.PytestWarning( f"Could not clean up after test: {e}" - ) + ) from None else: yield GoogleCloudStorage( project_name=MOCK_GCP_PROJECT_NAME, @@ -107,24 +106,24 @@ def vcr_config() -> dict[str, Any]: and os.environ.get("GCP_PROJECT_NAME") and os.environ.get("GCP_BUCKET_NAME") ) - if live_tests: - mode = "once" - else: - mode = "none" + mode = "once" if live_tests else "none" return { "filter_headers": [("authorization", "fake-authz-header")], "record_mode": mode, } -# FIXME: updating the storage backends has caused the VCR cassettes to -# become invalid. Datopian will need to rebuild those cassettes with data -# from the current implementation. +# TODO @athornton: updating the storage backends has caused the VCR cassettes +# to become invalid. Datopian will need to rebuild those cassettes with data +# from the current implementation, or (better) we should use something other +# than pytest-vcr, which is opaque and unhelpful. # # I can confirm that the Google Cloud Storage Backend at least works in # conjunction with Workload Identity, since I'm using that for my own storage # in my Git LFS implementation. 
-- AJT 20231220 # # @pytest.mark.vcr() -# class TestGoogleCloudStorageBackend(StreamingStorageAbstractTests, ExternalStorageAbstractTests): -# pass +# class TestGoogleCloudStorageBackend( +# StreamingStorageAbstractTests, ExternalStorageAbstractTests +# ): +# pass diff --git a/tests/storage/test_local.py b/tests/storage/test_local.py index 697230e..294d72a 100644 --- a/tests/storage/test_local.py +++ b/tests/storage/test_local.py @@ -1,9 +1,7 @@ -"""Tests for the local storage backend -""" -import os -import pathlib +"""Tests for the local storage backend.""" import shutil from collections.abc import Generator +from pathlib import Path import pytest @@ -13,30 +11,28 @@ @pytest.fixture -def storage_dir(tmp_path: pathlib.Path) -> Generator[pathlib.Path, None, None]: - """Create a unique temp dir for testing storage""" - dir = None +def storage_dir(tmp_path: Path) -> Generator[Path, None, None]: + """Create a unique temp dir for testing storage.""" + tdir = None try: - dir = tmp_path / "giftless_tests" - dir.mkdir(parents=True) - yield dir + tdir = tmp_path / "giftless_tests" + tdir.mkdir(parents=True) + yield tdir finally: - if dir and os.path.isdir(dir): - shutil.rmtree(dir) + if tdir and tdir.is_dir(): + shutil.rmtree(tdir) @pytest.fixture def storage_backend(storage_dir: str) -> LocalStorage: - """Provide a local storage backend for all local tests""" + """Provide a local storage backend for all local tests.""" return LocalStorage(path=storage_dir) class TestLocalStorageBackend(StreamingStorageAbstractTests): - def test_local_path_created_on_init( - self, storage_dir: pathlib.Path - ) -> None: - """Test that the local storage path is created on module init""" - storage_path = str(storage_dir / "here") - assert not os.path.exists(storage_path) - LocalStorage(path=storage_path) - assert os.path.exists(storage_path) + def test_local_path_created_on_init(self, storage_dir: Path) -> None: + """Test that the local storage path is created on module init.""" + 
storage_path = storage_dir / "here" + assert not storage_path.exists() + LocalStorage(path=str(storage_path)) + assert storage_path.exists() diff --git a/tests/test_batch_api.py b/tests/test_batch_api.py index cd8ef9a..ac8f9b3 100644 --- a/tests/test_batch_api.py +++ b/tests/test_batch_api.py @@ -1,5 +1,4 @@ -"""Tests for schema definitions -""" +"""Tests for schema definitions.""" from typing import cast import pytest @@ -8,9 +7,9 @@ from .helpers import batch_request_payload, create_file_in_storage -@pytest.mark.usefixtures("authz_full_access") +@pytest.mark.usefixtures("_authz_full_access") def test_upload_batch_request(test_client: FlaskClient) -> None: - """Test basic batch API with a basic successful upload request""" + """Test basic batch API with a basic successful upload request.""" request_payload = batch_request_payload(operation="upload") response = test_client.post( "/myorg/myrepo/objects/batch", json=request_payload @@ -24,18 +23,18 @@ def test_upload_batch_request(test_client: FlaskClient) -> None: assert payload["transfer"] == "basic" assert len(payload["objects"]) == 1 - object = payload["objects"][0] - assert object["oid"] == request_payload["objects"][0]["oid"] - assert object["size"] == request_payload["objects"][0]["size"] - assert len(object["actions"]) == 2 - assert "upload" in object["actions"] - assert "verify" in object["actions"] + obj = payload["objects"][0] + assert obj["oid"] == request_payload["objects"][0]["oid"] + assert obj["size"] == request_payload["objects"][0]["size"] + assert len(obj["actions"]) == 2 + assert "upload" in obj["actions"] + assert "verify" in obj["actions"] def test_download_batch_request( test_client: FlaskClient, storage_path: str ) -> None: - """Test basic batch API with a basic successful upload request""" + """Test basic batch API with a basic successful upload request.""" request_payload = batch_request_payload(operation="download") oid = request_payload["objects"][0]["oid"] 
create_file_in_storage(storage_path, "myorg", "myrepo", oid, size=8) @@ -52,17 +51,17 @@ def test_download_batch_request( assert payload["transfer"] == "basic" assert len(payload["objects"]) == 1 - object = payload["objects"][0] - assert object["oid"] == request_payload["objects"][0]["oid"] - assert object["size"] == request_payload["objects"][0]["size"] - assert len(object["actions"]) == 1 - assert "download" in object["actions"] + obj = payload["objects"][0] + assert obj["oid"] == request_payload["objects"][0]["oid"] + assert obj["size"] == request_payload["objects"][0]["size"] + assert len(obj["actions"]) == 1 + assert "download" in obj["actions"] def test_download_batch_request_two_files_one_missing( test_client: FlaskClient, storage_path: str ) -> None: - """Test batch API with a two object download request where one file 404""" + """Test batch API with a two object download request where one file 404.""" request_payload = batch_request_payload(operation="download") oid = request_payload["objects"][0]["oid"] create_file_in_storage(storage_path, "myorg", "myrepo", oid, size=8) @@ -82,23 +81,25 @@ def test_download_batch_request_two_files_one_missing( assert payload["transfer"] == "basic" assert len(payload["objects"]) == 2 - object = payload["objects"][0] - assert object["oid"] == request_payload["objects"][0]["oid"] - assert object["size"] == request_payload["objects"][0]["size"] - assert len(object["actions"]) == 1 - assert "download" in object["actions"] + obj = payload["objects"][0] + assert obj["oid"] == request_payload["objects"][0]["oid"] + assert obj["size"] == request_payload["objects"][0]["size"] + assert len(obj["actions"]) == 1 + assert "download" in obj["actions"] - object = payload["objects"][1] - assert object["oid"] == request_payload["objects"][1]["oid"] - assert object["size"] == request_payload["objects"][1]["size"] - assert "actions" not in object - assert object["error"]["code"] == 404 + obj = payload["objects"][1] + assert obj["oid"] == 
request_payload["objects"][1]["oid"] + assert obj["size"] == request_payload["objects"][1]["size"] + assert "actions" not in obj + assert obj["error"]["code"] == 404 def test_download_batch_request_two_files_missing( test_client: FlaskClient, ) -> None: - """Test batch API with a two object download request where one file 404""" + """Test batch API with a two object download request where both files + 404. + """ request_payload = batch_request_payload(operation="download") request_payload["objects"].append({"oid": "12345679", "size": 5555}) @@ -118,7 +119,7 @@ def test_download_batch_request_two_files_missing( def test_download_batch_request_two_files_one_mismatch( test_client: FlaskClient, storage_path: str ) -> None: - """Test batch API with a two object download request where one file 422""" + """Test batch API with a two object download request where one file 422.""" request_payload = batch_request_payload(operation="download") request_payload["objects"].append({"oid": "12345679", "size": 8}) @@ -149,23 +150,23 @@ def test_download_batch_request_two_files_one_mismatch( assert payload["transfer"] == "basic" assert len(payload["objects"]) == 2 - object = payload["objects"][0] - assert object["oid"] == request_payload["objects"][0]["oid"] - assert object["size"] == request_payload["objects"][0]["size"] - assert len(object["actions"]) == 1 - assert "download" in object["actions"] + obj = payload["objects"][0] + assert obj["oid"] == request_payload["objects"][0]["oid"] + assert obj["size"] == request_payload["objects"][0]["size"] + assert len(obj["actions"]) == 1 + assert "download" in obj["actions"] - object = payload["objects"][1] - assert object["oid"] == request_payload["objects"][1]["oid"] - assert object["size"] == request_payload["objects"][1]["size"] - assert "actions" not in object - assert object["error"]["code"] == 422 + obj = payload["objects"][1] + assert obj["oid"] == request_payload["objects"][1]["oid"] + assert obj["size"] == 
request_payload["objects"][1]["size"] + assert "actions" not in obj + assert obj["error"]["code"] == 422 def test_download_batch_request_one_file_mismatch( test_client: FlaskClient, storage_path: str ) -> None: - """Test batch API with a two object download request where one file 422""" + """Test batch API with a one object download request where the file 422.""" request_payload = batch_request_payload(operation="download") create_file_in_storage( storage_path, @@ -191,7 +192,9 @@ def test_download_batch_request_one_file_mismatch( def test_download_batch_request_two_files_different_errors( test_client: FlaskClient, storage_path: str ) -> None: - """Test batch API with a two object download request where one file is missing and one is mismatch""" + """Test batch API with a two object download request where one file is + missing and one is mismatch. + """ request_payload = batch_request_payload(operation="download") request_payload["objects"].append({"oid": "12345679", "size": 8}) create_file_in_storage( diff --git a/tests/test_error_responses.py b/tests/test_error_responses.py index ade7495..4bc1c2b 100644 --- a/tests/test_error_responses.py +++ b/tests/test_error_responses.py @@ -1,12 +1,11 @@ -"""Tests for schema definitions -""" +"""Tests for schema definitions.""" from flask.testing import FlaskClient from .helpers import batch_request_payload def test_error_response_422(test_client: FlaskClient) -> None: - """Test an invalid payload error""" + """Test an invalid payload error.""" response = test_client.post( "/myorg/myrepo/objects/batch", json=batch_request_payload(delete_keys=["operation"]), @@ -18,7 +17,7 @@ def test_error_response_422(test_client: FlaskClient) -> None: def test_error_response_404(test_client: FlaskClient) -> None: - """Test a bad route error""" + """Test a bad route error.""" response = test_client.get("/now/for/something/completely/different") assert response.status_code == 404 @@ -27,7 +26,9 @@ def test_error_response_404(test_client: 
FlaskClient) -> None: def test_error_response_403(test_client: FlaskClient) -> None: - """Test that we get Forbidden when trying to upload with the default read-only setup""" + """Test that we get Forbidden when trying to upload with the default + read-only setup. + """ response = test_client.post( "/myorg/myrepo/objects/batch", json=batch_request_payload(operation="upload"), diff --git a/tests/test_middleware.py b/tests/test_middleware.py index bc23aae..bd64c98 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -1,5 +1,4 @@ -"""Tests for using middleware and some specific middleware -""" +"""Tests for using middleware and some specific middleware.""" from typing import Any, cast import pytest @@ -13,8 +12,8 @@ @pytest.fixture def app(storage_path: str) -> Flask: - """Session fixture to configure the Flask app""" - app = init_app( + """Session fixture to configure the Flask app.""" + return init_app( additional_config={ "TESTING": True, "TRANSFER_ADAPTERS": { @@ -34,14 +33,15 @@ def app(storage_path: str) -> Flask: ], } ) - return app -@pytest.mark.usefixtures("authz_full_access") +@pytest.mark.usefixtures("_authz_full_access") def test_upload_request_with_x_forwarded_middleware( test_client: FlaskClient, ) -> None: - """Test the ProxyFix middleware generates correct URLs if X-Forwarded headers are set""" + """Test the ProxyFix middleware generates correct URLs if + X-Forwarded headers are set. 
+ """ request_payload = batch_request_payload(operation="upload") response = test_client.post( "/myorg/myrepo/objects/batch", json=request_payload diff --git a/tests/test_schema.py b/tests/test_schema.py index 6929626..1f16520 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -1,5 +1,4 @@ -"""Tests for schema definitions -""" +"""Tests for schema definitions.""" import pytest from marshmallow import ValidationError @@ -10,20 +9,20 @@ @pytest.mark.parametrize( - "input", + "inp", [ (batch_request_payload()), (batch_request_payload(operation="upload")), (batch_request_payload(delete_keys=["ref", "transfers"])), ], ) -def test_batch_request_schema_valid(input: str) -> None: - parsed = schema.BatchRequest().load(input) +def test_batch_request_schema_valid(inp: str) -> None: + parsed = schema.BatchRequest().load(inp) assert parsed @pytest.mark.parametrize( - "input", + "inp", [ ({}), (batch_request_payload(operation="sneeze")), @@ -36,14 +35,14 @@ def test_batch_request_schema_valid(input: str) -> None: (batch_request_payload(objects=[{"oid": "123abc", "size": -12}])), ], ) -def test_batch_request_schema_invalid(input: str) -> None: +def test_batch_request_schema_invalid(inp: str) -> None: with pytest.raises(ValidationError): - schema.BatchRequest().load(input) + schema.BatchRequest().load(inp) def test_batch_request_default_transfer() -> None: - input = batch_request_payload(delete_keys=["transfers"]) - parsed = schema.BatchRequest().load(input) + inp = batch_request_payload(delete_keys=["transfers"]) + parsed = schema.BatchRequest().load(inp) assert ["basic"] == parsed["transfers"] diff --git a/tests/transfer/conftest.py b/tests/transfer/conftest.py index 2c49f83..805ee43 100644 --- a/tests/transfer/conftest.py +++ b/tests/transfer/conftest.py @@ -1,6 +1,5 @@ -"""Some global fixtures for transfer tests -""" -from typing import Generator +"""Some global fixtures for transfer tests.""" +from collections.abc import Generator import pytest @@ -8,8 +7,8 @@ 
@pytest.fixture -def reset_registered_transfers() -> Generator: - """Reset global registered transfer adapters for each module""" +def _reset_registered_transfers() -> Generator: + """Reset global registered transfer adapters for each module.""" adapters = dict(transfer._registered_adapters) try: yield diff --git a/tests/transfer/test_basic_external_adapter.py b/tests/transfer/test_basic_external_adapter.py index 4d18a2a..ee3f636 100644 --- a/tests/transfer/test_basic_external_adapter.py +++ b/tests/transfer/test_basic_external_adapter.py @@ -1,15 +1,16 @@ -from typing import Any, Optional +"""Test basic_external transfer adapter functionality.""" +from typing import Any from urllib.parse import urlencode import pytest from giftless.storage import ExternalStorage -from giftless.storage.exc import ObjectNotFound +from giftless.storage.exc import ObjectNotFoundError from giftless.transfer import basic_external def test_factory_returns_object() -> None: - """Test that the basic_external factory returns the right object(s)""" + """Test that the basic_external factory returns the right object(s).""" base_url = "https://s4.example.com/" lifetime = 300 adapter = basic_external.factory( @@ -178,7 +179,7 @@ def test_download_action_extras_are_passed() -> None: class MockExternalStorageBackend(ExternalStorage): - """A mock adapter for the basic external transfer adapter + """Implementation of mock adapter for the basic external transfer adapter. 
Typically, "external" backends are cloud providers - so this backend can be used in testing to test the transfer adapter's behavior without @@ -198,7 +199,7 @@ def get_size(self, prefix: str, oid: str) -> int: try: return self.existing_objects[(prefix, oid)] except KeyError: - raise ObjectNotFound("Object does not exist") + raise ObjectNotFoundError("Object does not exist") from None def get_upload_action( self, @@ -206,7 +207,7 @@ def get_upload_action( oid: str, size: int, expires_in: int, - extra: Optional[dict[str, Any]] = None, + extra: dict[str, Any] | None = None, ) -> dict[str, Any]: return { "actions": { @@ -226,7 +227,7 @@ def get_download_action( oid: str, size: int, expires_in: int, - extra: Optional[dict[str, Any]] = None, + extra: dict[str, Any] | None = None, ) -> dict[str, Any]: return { "actions": { @@ -245,7 +246,7 @@ def _get_signed_url( prefix: str, oid: str, expires_in: int, - extra: Optional[dict[str, Any]] = None, + extra: dict[str, Any] | None = None, ) -> str: url = f"{self.base_url}{prefix}/{oid}?expires_in={expires_in}" if extra: diff --git a/tests/transfer/test_module.py b/tests/transfer/test_module.py index 7f424fb..2a8c454 100644 --- a/tests/transfer/test_module.py +++ b/tests/transfer/test_module.py @@ -1,6 +1,4 @@ -"""Test common transfer module functionality -""" -from typing import Any +"""Test common transfer module functionality.""" import pytest @@ -8,7 +6,7 @@ @pytest.mark.parametrize( - "register,requested,expected", + ("register", "requested", "expected"), [ (["basic"], ["basic"], "basic"), (["foobar", "basic", "bizbaz"], ["basic"], "basic"), @@ -16,7 +14,7 @@ (["foobar", "basic", "bizbaz"], ["bizbaz", "basic"], "bizbaz"), ], ) -@pytest.mark.usefixtures("reset_registered_transfers") +@pytest.mark.usefixtures("_reset_registered_transfers") def test_transfer_adapter_matching( register: list[str], requested: list[str], expected: str ) -> None: @@ -30,5 +28,5 @@ def test_transfer_adapter_matching( def 
test_transfer_adapter_matching_nomatch() -> None: for adapter in ["foobar", "basic", "bizbaz"]: transfer.register_adapter(adapter, transfer.TransferAdapter()) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Unable to match"): transfer.match_transfer_adapter(["complex", "even-better"]) From e850fd59e59a9cfc203a8a6b63f14a6ef8760d34 Mon Sep 17 00:00:00 2001 From: adam Date: Wed, 10 Jan 2024 17:49:44 -0700 Subject: [PATCH 2/4] Unpin Sphinx, freshen reqs, fix up docs/tests --- giftless/storage/azure.py | 8 ++------ giftless/util.py | 6 ++---- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/giftless/storage/azure.py b/giftless/storage/azure.py index 40775f8..0961eef 100644 --- a/giftless/storage/azure.py +++ b/giftless/storage/azure.py @@ -393,14 +393,10 @@ def _calculate_blocks(file_size: int, part_size: int) -> list[Block]: """Calculate the list of blocks in a blob. >>> _calculate_blocks(30, 10) - [Block(id=0, start=0, size=10), - Block(id=1, start=10, size=10), - Block(id=2, start=20, size=10)] + [Block(id=0, start=0, size=10), Block(id=1, start=10, size=10), Block(id=2, start=20, size=10)] >>> _calculate_blocks(28, 10) - [Block(id=0, start=0, size=10), - Block(id=1, start=10, size=10), - Block(id=2, start=20, size=8)] + [Block(id=0, start=0, size=10), Block(id=1, start=10, size=10), Block(id=2, start=20, size=8)] >>> _calculate_blocks(7, 10) [Block(id=0, start=0, size=7)] diff --git a/giftless/util.py b/giftless/util.py index 0ca7c37..a875083 100644 --- a/giftless/util.py +++ b/giftless/util.py @@ -63,12 +63,10 @@ def add_query_params(url: str, params: dict[str, Any]) -> str: """Safely add query params to a url that may or may not already contain query params. 
- >>> add_query_params( - 'https://example.org', {'param1': 'value1', 'param2': 'value2'}) + >>> add_query_params('https://example.org', {'param1': 'value1', 'param2': 'value2'}) 'https://example.org?param1=value1¶m2=value2' - >>> add_query_params( - 'https://example.org?param1=value1', {'param2': 'value2'}) + >>> add_query_params('https://example.org?param1=value1', {'param2': 'value2'}) # noqa[E501] 'https://example.org?param1=value1¶m2=value2' """ # noqa: E501 urlencoded_params = urlencode(params) From 1de5a0ca973cbb8abfb791e90678491d920688ab Mon Sep 17 00:00:00 2001 From: adam Date: Fri, 12 Jan 2024 16:03:00 -0700 Subject: [PATCH 3/4] tighten typing again --- giftless/app.py | 2 +- giftless/auth/__init__.py | 4 ++-- giftless/auth/identity.py | 2 +- giftless/config.py | 2 +- giftless/schema.py | 6 +++--- giftless/storage/azure.py | 2 +- giftless/storage/google_cloud.py | 6 +++--- giftless/storage/local_storage.py | 6 +++--- giftless/transfer/basic_external.py | 6 +++--- giftless/transfer/basic_streaming.py | 2 +- giftless/transfer/multipart.py | 4 ++-- giftless/util.py | 4 ++-- 12 files changed, 23 insertions(+), 23 deletions(-) diff --git a/giftless/app.py b/giftless/app.py index 8a61481..bb2162b 100644 --- a/giftless/app.py +++ b/giftless/app.py @@ -59,4 +59,4 @@ def _load_middleware(flask_app: Flask) -> None: wsgi_app = klass(wsgi_app, *args, **kwargs) log.debug(f"Loaded middleware: {klass}(*{args}, **{kwargs}") - flask_app.wsgi_app = wsgi_app + flask_app.wsgi_app = wsgi_app # type:ignore[method-assign] diff --git a/giftless/auth/__init__.py b/giftless/auth/__init__.py index 2020c88..f8a336e 100644 --- a/giftless/auth/__init__.py +++ b/giftless/auth/__init__.py @@ -3,7 +3,7 @@ import logging from collections.abc import Callable from functools import wraps -from typing import Any +from typing import Any, cast from flask import Flask, Request, current_app, g from flask import request as flask_request @@ -209,7 +209,7 @@ def _create_authenticator(spec: str | 
dict[str, Any]) -> Authenticator: log.debug(f"Creating authenticator using factory: {spec['factory']}") factory = get_callable(spec["factory"], __name__) options = spec.get("options", {}) - return factory(**options) + return cast(Authenticator, factory(**options)) authentication = Authentication() diff --git a/giftless/auth/identity.py b/giftless/auth/identity.py index 8c669e1..11055c4 100644 --- a/giftless/auth/identity.py +++ b/giftless/auth/identity.py @@ -17,7 +17,7 @@ def all(cls) -> set["Permission"]: PermissionTree = dict[ - str | None, dict[str | None], dict[str | None, set[Permission]] + str | None, dict[str | None, dict[str | None, set[Permission]]] ] diff --git a/giftless/config.py b/giftless/config.py index 149716b..0d541d7 100644 --- a/giftless/config.py +++ b/giftless/config.py @@ -5,7 +5,7 @@ import yaml from dotenv import load_dotenv -from figcan import Configuration, Extensible +from figcan import Configuration, Extensible # type:ignore[attr-defined] from flask import Flask ENV_PREFIX = "GIFTLESS_" diff --git a/giftless/schema.py b/giftless/schema.py index 0e06cd1..7012ae0 100644 --- a/giftless/schema.py +++ b/giftless/schema.py @@ -20,13 +20,13 @@ class Operation(Enum): download = "download" -class RefSchema(ma.Schema): +class RefSchema(ma.Schema): # type:ignore[name-defined] """ref field schema.""" name = fields.String(required=True) -class ObjectSchema(ma.Schema): +class ObjectSchema(ma.Schema): # type:ignore[name-defined] """object field schema.""" oid = fields.String(required=True) @@ -48,7 +48,7 @@ def set_extra_fields( return {"extra": extra, **rest} -class BatchRequest(ma.Schema): +class BatchRequest(ma.Schema): # type:ignore[name-defined] """batch request schema.""" operation = EnumField(Operation, required=True) diff --git a/giftless/storage/azure.py b/giftless/storage/azure.py index 0961eef..3591cdd 100644 --- a/giftless/storage/azure.py +++ b/giftless/storage/azure.py @@ -283,7 +283,7 @@ def _get_signed_url( blob_name=blob_name, 
credential=sas_token, ) - return blob_client.url + return str(blob_client.url) def _get_uncommitted_blocks( self, prefix: str, oid: str, blocks: list[Block] diff --git a/giftless/storage/google_cloud.py b/giftless/storage/google_cloud.py index bb25445..ec10198 100644 --- a/giftless/storage/google_cloud.py +++ b/giftless/storage/google_cloud.py @@ -6,7 +6,7 @@ import json import posixpath from datetime import timedelta -from typing import Any, BinaryIO +from typing import Any, BinaryIO, cast import google.auth from google.auth import impersonated_credentials @@ -70,14 +70,14 @@ def put(self, prefix: str, oid: str, data_stream: BinaryIO) -> int: def exists(self, prefix: str, oid: str) -> bool: bucket = self.storage_client.bucket(self.bucket_name) blob = bucket.blob(self._get_blob_path(prefix, oid)) - return blob.exists() + return cast(bool, blob.exists()) def get_size(self, prefix: str, oid: str) -> int: bucket = self.storage_client.bucket(self.bucket_name) blob = bucket.get_blob(self._get_blob_path(prefix, oid)) if blob is None: raise ObjectNotFoundError("Object does not exist") - return blob.size + return cast(int, blob.size) def get_upload_action( self, diff --git a/giftless/storage/local_storage.py b/giftless/storage/local_storage.py index b324a45..0178ff4 100644 --- a/giftless/storage/local_storage.py +++ b/giftless/storage/local_storage.py @@ -36,7 +36,7 @@ def get(self, prefix: str, oid: str) -> BinaryIO: def put(self, prefix: str, oid: str, data_stream: BinaryIO) -> int: path = self._get_path(prefix, oid) directory = path.parent - self._create_path(directory) + self._create_path(str(directory)) with path.open("bw") as dest: shutil.copyfileobj(data_stream, dest) return dest.tell() @@ -84,7 +84,7 @@ def _get_path(self, prefix: str, oid: str) -> Path: return Path(self.path) / prefix / oid @staticmethod - def _create_path(path: str) -> None: - path = Path(path) + def _create_path(spath: str) -> None: + path = Path(spath) if not path.is_dir(): 
path.mkdir(parents=True) diff --git a/giftless/transfer/basic_external.py b/giftless/transfer/basic_external.py index 9e84c35..01e7130 100644 --- a/giftless/transfer/basic_external.py +++ b/giftless/transfer/basic_external.py @@ -59,7 +59,7 @@ def upload( prefix, oid, size, self.action_lifetime, extra ) ) - if response.get("actions", {}).get("upload"): + if response.get("actions", {}).get("upload"): # type:ignore[attr-defined] response["authenticated"] = True headers = self._preauth_headers( organization, @@ -68,7 +68,7 @@ def upload( oid=oid, lifetime=self.VERIFY_LIFETIME, ) - response["actions"]["verify"] = { + response["actions"]["verify"] = { # type:ignore[index] "href": VerifyView.get_verify_url(organization, repo), "header": headers, "expires_in": self.VERIFY_LIFETIME, @@ -97,7 +97,7 @@ def download( except exc.StorageError as e: response["error"] = e.as_dict() - if response.get("actions", {}).get("download"): + if response.get("actions", {}).get("download"): # type:ignore[attr-defined] response["authenticated"] = True return response diff --git a/giftless/transfer/basic_streaming.py b/giftless/transfer/basic_streaming.py index d20f09d..b62432b 100644 --- a/giftless/transfer/basic_streaming.py +++ b/giftless/transfer/basic_streaming.py @@ -208,7 +208,7 @@ def download( repo: str, oid: str, size: int, - extra: dict[str, Any] | None, + extra: dict[str, Any] | None = None, ) -> dict: response = {"oid": oid, "size": size} diff --git a/giftless/transfer/multipart.py b/giftless/transfer/multipart.py index 2fd18a9..790ab07 100644 --- a/giftless/transfer/multipart.py +++ b/giftless/transfer/multipart.py @@ -57,7 +57,7 @@ def upload( oid=oid, lifetime=self.VERIFY_LIFETIME, ) - response["actions"]["verify"] = { + response["actions"]["verify"] = { # type: ignore[index] "href": VerifyView.get_verify_url(organization, repo), "header": headers, "expires_in": self.VERIFY_LIFETIME, @@ -86,7 +86,7 @@ def download( except exc.StorageError as e: response["error"] = e.as_dict() 
- if response.get("actions", {}).get("download"): + if response.get("actions", {}).get("download"): # type:ignore[attr-defined] response["authenticated"] = True return response diff --git a/giftless/util.py b/giftless/util.py index a875083..f72b00a 100644 --- a/giftless/util.py +++ b/giftless/util.py @@ -1,7 +1,7 @@ """Miscellanea.""" import importlib from collections.abc import Callable, Iterable -from typing import Any +from typing import Any, cast from urllib.parse import urlencode @@ -28,7 +28,7 @@ def get_callable( "Expecting base_package to be set if only class name is provided" ) - return getattr(module, callable_name) + return cast(Callable, getattr(module, callable_name)) def to_iterable(val: Any) -> Iterable: From aa55b12bcb712d377649f31e287d27d71ec3ec8a Mon Sep 17 00:00:00 2001 From: adam Date: Fri, 12 Jan 2024 16:38:01 -0700 Subject: [PATCH 4/4] fix coverage --- tox.ini | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/tox.ini b/tox.ini index 938f80b..6e0b98f 100644 --- a/tox.ini +++ b/tox.ini @@ -12,14 +12,6 @@ deps = -rrequirements/main.txt -rrequirements/dev.txt -[testenv:coverage-report] -description = Compile coverage from each test run. -skip_install = true -deps = coverage[toml]>=5.0.2 -depends = - py-coverage -commands = coverage report - [testenv:lint] description = Lint codebase by running pre-commit (Black, isort, Flake8) skip_install = true @@ -30,12 +22,15 @@ commands = pre-commit run --all-files [testenv:py] description = Run pytest commands = - pytest -vv {posargs} + pytest -vv {posargs} --cov=giftless -[testenv:py-coverage] -description = Run pytest with Docker prerequisites and coverage analysis -commands = - pytest -vv --cov=giftless --cov-branch --cov-report= {posargs} +[testenv:coverage-report] +description = Compile coverage from each test run. 
+skip_install = true +deps = coverage[toml]>=5.0.2 +depends = + py-coverage +commands = coverage report [testenv:docs] description = Build documentation (HTML) with Sphinx