Skip to content

Commit

Permalink
fix: typing
Browse files Browse the repository at this point in the history
  • Loading branch information
Ravencentric committed Sep 17, 2024
1 parent 64a9e2a commit d05c1cc
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 21 deletions.
42 changes: 23 additions & 19 deletions src/pynyaa/_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing_extensions import Self

from pynyaa._compat import IntEnum, StrEnum
from pynyaa._types import CategoryID, CategoryName, SortName
from pynyaa._types import CategoryID, CategoryLiteral, SortByLiteral
from pynyaa._utils import get_category_id_from_name


Expand Down Expand Up @@ -67,31 +67,31 @@ def id(self) -> CategoryID:

@overload
@classmethod
def get(cls, key: CategoryName, default: CategoryName = "All") -> Self: ...
def get(cls, key: CategoryLiteral, default: CategoryLiteral = "All") -> Self: ...

@overload
@classmethod
def get(cls, key: CategoryName, default: str = "All") -> Self: ...
def get(cls, key: CategoryLiteral, default: str = "All") -> Self: ...

@overload
@classmethod
def get(cls, key: str, default: CategoryName = "All") -> Self: ...
def get(cls, key: str, default: CategoryLiteral = "All") -> Self: ...

@overload
@classmethod
def get(cls, key: str, default: str = "All") -> Self: ...

@classmethod
def get(cls, key: CategoryName | str, default: CategoryName | str = "All") -> Self:
def get(cls, key: CategoryLiteral | str, default: CategoryLiteral | str = "All") -> Self:
"""
Get the `Category` by its name (case-insensitive).
Return the default if the key is missing or invalid.
Parameters
----------
key : CategoryName | str
key : CategoryLiteral | str
The key to retrieve.
default : CategoryName | str, optional
default : CategoryLiteral | str, optional
The default value to return if the key is missing or invalid.
Returns
Expand All @@ -102,7 +102,7 @@ def get(cls, key: CategoryName | str, default: CategoryName | str = "All") -> Se
match key:
case str():
for category in cls:
if category.value.casefold() == key.casefold():
if (category.value.casefold() == key.casefold()) or (category.name.casefold() == key.casefold()):
return category
else:
return cls(default)
Expand All @@ -120,46 +120,50 @@ class SortBy(BaseStrEnum):

@overload
@classmethod
def get(cls, key: SortName, default: SortName = "datetime") -> Self: ...
def get(cls, key: SortByLiteral, default: SortByLiteral = "datetime") -> Self: ...

@overload
@classmethod
def get(cls, key: SortName, default: str = "datetime") -> Self: ...
def get(cls, key: SortByLiteral, default: str = "datetime") -> Self: ...

@overload
@classmethod
def get(cls, key: str, default: SortName = "datetime") -> Self: ...
def get(cls, key: str, default: SortByLiteral = "datetime") -> Self: ...

@overload
@classmethod
def get(cls, key: str, default: str = "datetime") -> Self: ...

@classmethod
def get(cls, key: SortName | str, default: SortName | str = "datetime") -> Self:
def get(cls, key: SortByLiteral | str, default: SortByLiteral | str = "datetime") -> Self:
"""
Get the `SortBy` by its name (case-insensitive).
Return the default if the key is missing or invalid.
Parameters
----------
key : SortName | str
key : SortByLiteral | str
The key to retrieve.
default : SortName | str, optional
default : SortByLiteral | str, optional
The default value to return if the key is missing or invalid.
Returns
-------
Category
The `SortBy` corresponding to the key.
"""

# "datetime" doesn't actually exist, it's just an alias for "id"
default = "id" if default.casefold() == "datetime" else default.casefold()
key = key.casefold()
key = "id" if key.casefold() == "datetime" else key.casefold()

match key:
case "comments" | "size" | "id" | "seeders" | "leechers" | "downloads":
return cls(key)
case "datetime":
return cls("id")
case str():
for category in cls:
if (category.value.casefold() == key.casefold()) or (category.name.casefold() == key.casefold()):
return category
else:
return cls(default)
case _:
return cls(default)

Expand Down
33 changes: 31 additions & 2 deletions src/pynyaa/_types.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,40 @@
from __future__ import annotations

from typing import Annotated, Literal
from typing import Annotated, Literal, Union

from pydantic import AnyUrl, UrlConstraints

MagnetUrl = Annotated[AnyUrl, UrlConstraints(allowed_schemes=["magnet"])]
"""Url that only allows magnets."""

CategoryName = Literal[
"ALL",
"ANIME",
"ANIME_MUSIC_VIDEO",
"ANIME_ENGLISH_TRANSLATED",
"ANIME_NON_ENGLISH_TRANSLATED",
"ANIME_RAW",
"AUDIO",
"AUDIO_LOSSLESS",
"AUDIO_LOSSY",
"LITERATURE",
"LITERATURE_ENGLISH_TRANSLATED",
"LITERATURE_NON_ENGLISH_TRANSLATED",
"LITERATURE_RAW",
"LIVE_ACTION",
"LIVE_ACTION_ENGLISH_TRANSLATED",
"LIVE_ACTION_IDOL_PROMOTIONAL_VIDEO",
"LIVE_ACTION_NON_ENGLISH_TRANSLATED",
"LIVE_ACTION_RAW",
"PICTURES",
"PICTURES_GRAPHICS",
"PICTURES_PHOTOS",
"SOFTWARE",
"SOFTWARE_APPLICATIONS",
"SOFTWARE_GAMES",
]

CategoryValue = Literal[
"All",
"Anime",
"Anime - Anime Music Video",
Expand All @@ -34,6 +61,8 @@
"Software - Games",
]

CategoryLiteral = Union[CategoryName, CategoryValue]

CategoryID = Literal[
"0_0",
"1_0",
Expand Down Expand Up @@ -61,4 +90,4 @@
"6_2",
]

SortName = Literal["comments", "size", "id", "datetime", "seeders", "leechers", "downloads"]
SortByLiteral = Literal["comments", "size", "id", "datetime", "seeders", "leechers", "downloads"]
100 changes: 100 additions & 0 deletions tests/test_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from __future__ import annotations

from typing import get_args

from pynyaa._types import CategoryID, CategoryName, CategoryValue, SortByLiteral


def test_category_name_literals() -> None:
assert get_args(CategoryName) == (
"ALL",
"ANIME",
"ANIME_MUSIC_VIDEO",
"ANIME_ENGLISH_TRANSLATED",
"ANIME_NON_ENGLISH_TRANSLATED",
"ANIME_RAW",
"AUDIO",
"AUDIO_LOSSLESS",
"AUDIO_LOSSY",
"LITERATURE",
"LITERATURE_ENGLISH_TRANSLATED",
"LITERATURE_NON_ENGLISH_TRANSLATED",
"LITERATURE_RAW",
"LIVE_ACTION",
"LIVE_ACTION_ENGLISH_TRANSLATED",
"LIVE_ACTION_IDOL_PROMOTIONAL_VIDEO",
"LIVE_ACTION_NON_ENGLISH_TRANSLATED",
"LIVE_ACTION_RAW",
"PICTURES",
"PICTURES_GRAPHICS",
"PICTURES_PHOTOS",
"SOFTWARE",
"SOFTWARE_APPLICATIONS",
"SOFTWARE_GAMES",
)


def test_category_value_literals() -> None:
assert get_args(CategoryValue) == (
"All",
"Anime",
"Anime - Anime Music Video",
"Anime - English-translated",
"Anime - Non-English-translated",
"Anime - Raw",
"Audio",
"Audio - Lossless",
"Audio - Lossy",
"Literature",
"Literature - English-translated",
"Literature - Non-English-translated",
"Literature - Raw",
"Live Action",
"Live Action - English-translated",
"Live Action - Idol/Promotional Video",
"Live Action - Non-English-translated",
"Live Action - Raw",
"Pictures",
"Pictures - Graphics",
"Pictures - Photos",
"Software",
"Software - Applications",
"Software - Games",
)


def test_category_id_literal() -> None:
assert get_args(CategoryID) == (
"0_0",
"1_0",
"1_1",
"1_2",
"1_3",
"1_4",
"2_0",
"2_1",
"2_2",
"3_0",
"3_1",
"3_2",
"3_3",
"4_0",
"4_1",
"4_2",
"4_3",
"4_4",
"5_0",
"5_1",
"5_2",
"6_0",
"6_1",
"6_2",
)


def test_category_literals_length() -> None:
assert len(get_args(CategoryName)) == len(get_args(CategoryValue)) == len(get_args(CategoryID)) == 24


def test_sort_by_literals() -> None:
assert get_args(SortByLiteral) == ("comments", "size", "id", "datetime", "seeders", "leechers", "downloads")

0 comments on commit d05c1cc

Please sign in to comment.