Skip to content

Commit

Permalink
Merge pull request #159 from the-mama-ai/fix-preauth
Browse files Browse the repository at this point in the history
Fix authenticated flag in a transfer adapter response
  • Loading branch information
athornton authored May 7, 2024
2 parents 75b046b + fd4812a commit 67ee430
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 49 deletions.
2 changes: 1 addition & 1 deletion giftless/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def get_authz_query_params(
oid: str | None = None,
lifetime: int | None = None,
) -> dict[str, str]:
"""Authorize an action by adding credientaisl to the query string."""
"""Authorize an action by adding credentials to the query string."""

@abc.abstractmethod
def get_authz_header(
Expand Down
49 changes: 31 additions & 18 deletions giftless/auth/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,11 +243,17 @@ class GithubIdentity(Identity):
"""

def __init__(
self, login: str, id_: str, name: str, email: str, *, cc: CacheConfig
self,
login: str,
github_id: str,
name: str,
email: str,
*,
cc: CacheConfig,
) -> None:
super().__init__()
self.login = login
self.id = id_
self.id = login
self.github_id = github_id
self.name = name
self.email = email

Expand All @@ -274,30 +280,36 @@ def expiration(_key: Any, value: set[Permission], now: float) -> float:

def __repr__(self) -> str:
return (
f"<{self.__class__.__name__}"
f"login:{self.login} id:{self.id} name:{self.name}>"
f"<{self.__class__.__name__} "
f"id:{self.id} github_id:{self.github_id} name:{self.name}>"
)

def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__) and (self.login, self.id) == (
other.login,
other.id,
)
return isinstance(other, self.__class__) and (
self.id,
self.github_id,
) == (other.id, other.github_id)

def __hash__(self) -> int:
return hash((self.login, self.id))
return hash((self.id, self.github_id))

def permissions(
self, org: str, repo: str, *, authoritative: bool = False
) -> set[Permission] | None:
"""Return user's permission set for an org/repo."""
key = cachetools.keys.hashkey(org, repo)
with self._auth_cache_lock:
# first check if the permissions are in the proxy cache
if authoritative:
# pop the entry from the proxy cache to be stored properly
permission = self._auth_cache_read_proxy.pop(key, None)
else:
# just get it when only peeking
permission = self._auth_cache_read_proxy.get(key)
# if not found in the proxy, check the regular auth cache
if permission is None:
return self._auth_cache.get(key)
# try moving proxy permissions to the regular cache
if authoritative:
with suppress(ValueError):
self._auth_cache[key] = permission
Expand All @@ -306,7 +318,10 @@ def permissions(
def authorize(
self, org: str, repo: str, permissions: set[Permission] | None
) -> None:
"""Save user's permission set for an org/repo."""
key = cachetools.keys.hashkey(org, repo)
# put the discovered permissions into the proxy cache
# to ensure at least one successful 'authoritative' read
with self._auth_cache_lock:
self._auth_cache_read_proxy[key] = (
permissions if permissions is not None else set()
Expand Down Expand Up @@ -439,29 +454,27 @@ def _authorize(self, ctx: CallContext, user: GithubIdentity) -> None:
if (permissions := user.permissions(org, repo)) is not None:
perm_list = self._perm_list(permissions)
_logger.debug(
f"{user.login} is already temporarily authorized for "
f"{user.id} is already temporarily authorized for "
f"{org_repo}: {perm_list}"
)
else:
_logger.debug(
f"Checking {user.login}'s permissions for {org_repo}"
)
_logger.debug(f"Checking {user.id}'s permissions for {org_repo}")
try:
repo_data = self._api_get(
f"/repos/{org_repo}/collaborators/{user.login}/permission",
f"/repos/{org_repo}/collaborators/{user.id}/permission",
ctx,
)
except requests.exceptions.RequestException as e:
msg = (
f"Failed to find {user.login}'s permissions for "
f"Failed to find {user.id}'s permissions for "
f"{org_repo}: {e}"
)
_logger.warning(msg)
raise Unauthorized(msg) from None

gh_permission = repo_data.get("permission")
_logger.debug(
f"User {user.login} has '{gh_permission}' GitHub permission "
f"User {user.id} has '{gh_permission}' GitHub permission "
f"for {org_repo}"
)
permissions = set()
Expand All @@ -472,7 +485,7 @@ def _authorize(self, ctx: CallContext, user: GithubIdentity) -> None:
perm_list = self._perm_list(permissions)
ttl = user.cache_ttl(permissions)
_logger.debug(
f"Authorizing {user.login} (for {ttl}s) for "
f"Authorizing {user.id} (for {ttl}s) for "
f"{org_repo}: {perm_list}"
)
user.authorize(org, repo, permissions)
Expand Down
44 changes: 24 additions & 20 deletions giftless/transfer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
PreAuthorizedActionAuthenticator,
authentication,
)
from giftless.auth.identity import Identity
from giftless.util import add_query_params, get_callable
from giftless.view import ViewProvider

Expand Down Expand Up @@ -82,6 +83,25 @@ def __init__(self) -> None:
def set_auth_module(self, auth_module: Authentication) -> None:
self._auth_module = auth_module

@property
def _preauth_handler_and_identity(
self,
) -> tuple[PreAuthorizedActionAuthenticator | None, Identity | None]:
if (
self._auth_module is None
or self._auth_module.preauth_handler is None
):
return None, None
handler = cast(
PreAuthorizedActionAuthenticator, self._auth_module.preauth_handler
)
identity = self._auth_module.get_identity()
return handler, identity

@property
def _provides_preauth(self) -> bool:
return None not in self._preauth_handler_and_identity

def _preauth_url(
self,
original_url: str,
Expand All @@ -91,16 +111,8 @@ def _preauth_url(
oid: str | None = None,
lifetime: int | None = None,
) -> str:
if self._auth_module is None:
return original_url
if self._auth_module.preauth_handler is None:
return original_url

handler = cast(
PreAuthorizedActionAuthenticator, self._auth_module.preauth_handler
)
identity = self._auth_module.get_identity()
if identity is None:
handler, identity = self._preauth_handler_and_identity
if handler is None or identity is None:
return original_url

params = handler.get_authz_query_params(
Expand All @@ -117,18 +129,10 @@ def _preauth_headers(
oid: str | None = None,
lifetime: int | None = None,
) -> dict[str, str]:
if self._auth_module is None:
handler, identity = self._preauth_handler_and_identity
if handler is None or identity is None:
return {}
if self._auth_module.preauth_handler is None:
return {}

handler = cast(
PreAuthorizedActionAuthenticator, self._auth_module.preauth_handler
)

identity = self._auth_module.get_identity()
if identity is None:
return {}
return handler.get_authz_header(
identity, org, repo, actions, oid, lifetime=lifetime
)
Expand Down
4 changes: 2 additions & 2 deletions giftless/transfer/basic_external.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def upload(
)
)
if response.get("actions", {}).get("upload"): # type:ignore[attr-defined]
response["authenticated"] = True
response["authenticated"] = self._provides_preauth
headers = self._preauth_headers(
organization,
repo,
Expand Down Expand Up @@ -98,7 +98,7 @@ def download(
response["error"] = e.as_dict()

if response.get("actions", {}).get("download"): # type:ignore[attr-defined]
response["authenticated"] = True
response["authenticated"] = self._provides_preauth

return response

Expand Down
4 changes: 2 additions & 2 deletions giftless/transfer/basic_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def upload(
"expires_in": self.VERIFY_LIFETIME,
},
}
response["authenticated"] = True
response["authenticated"] = self._provides_preauth

return response

Expand Down Expand Up @@ -250,7 +250,7 @@ def download(
"expires_in": self.action_lifetime,
}
}
response["authenticated"] = True
response["authenticated"] = self._provides_preauth

return response

Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ build-backend = "setuptools.build_meta"
# Use single-quoted strings so TOML treats the string like a Python r-string
# Multi-line strings are implicitly treated by black as regular expressions

[tool.setuptools.packages.find]
include = ["giftless"]

[tool.coverage.run]
parallel = true
branch = true
Expand Down
9 changes: 7 additions & 2 deletions tests/auth/test_github.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,14 @@ def test_github_identity_core() -> None:
user_dict = DEFAULT_USER_DICT | {"other_field": "other_value"}
cache_cfg = DEFAULT_CONFIG.cache
user = gh.GithubIdentity.from_dict(user_dict, cc=cache_cfg)
assert (user.login, user.id, user.name, user.email) == DEFAULT_USER_ARGS
assert (
user.id,
user.github_id,
user.name,
user.email,
) == DEFAULT_USER_ARGS
assert all(arg in repr(user) for arg in DEFAULT_USER_ARGS[:3])
assert hash(user) == hash((user.login, user.id))
assert hash(user) == hash((user.id, user.github_id))

args2 = (*DEFAULT_USER_ARGS[:2], "spammer", "[email protected]")
user2 = gh.GithubIdentity(*args2, cc=cache_cfg)
Expand Down
8 changes: 4 additions & 4 deletions tests/transfer/test_basic_external_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_upload_action_new_file(app: flask.Flask) -> None:
assert response == {
"oid": "abcdef123456",
"size": 1234,
"authenticated": True,
"authenticated": False,
"actions": {
"upload": {
"href": "https://cloudstorage.example.com/myorg/myrepo/abcdef123456?expires_in=900",
Expand Down Expand Up @@ -75,7 +75,7 @@ def test_upload_action_extras_are_passed(app: flask.Flask) -> None:
assert response == {
"oid": "abcdef123456",
"size": 1234,
"authenticated": True,
"authenticated": False,
"actions": {
"upload": {
"href": "https://cloudstorage.example.com/myorg/myrepo/abcdef123456?expires_in=900&filename=foo.csv",
Expand Down Expand Up @@ -120,7 +120,7 @@ def test_download_action_existing_file() -> None:
assert response == {
"oid": "abcdef123456",
"size": 1234,
"authenticated": True,
"authenticated": False,
"actions": {
"download": {
"href": "https://cloudstorage.example.com/myorg/myrepo/abcdef123456?expires_in=900",
Expand Down Expand Up @@ -177,7 +177,7 @@ def test_download_action_extras_are_passed() -> None:
assert response == {
"oid": "abcdef123456",
"size": 1234,
"authenticated": True,
"authenticated": False,
"actions": {
"download": {
"href": "https://cloudstorage.example.com/myorg/myrepo/abcdef123456?expires_in=900&filename=foo.csv",
Expand Down

0 comments on commit 67ee430

Please sign in to comment.