Skip to content

Commit

Permalink
strict typing for xmlrpc (pypi#14306)
Browse files Browse the repository at this point in the history
resolves pypi#14302
  • Loading branch information
ewdurbin authored Aug 8, 2023
1 parent 47947ba commit c94bbf6
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 13 deletions.
2 changes: 1 addition & 1 deletion tests/functional/legacy_api/test_xmlrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_multiple_garbage_types(app_config, webtest):
xmlrpc.client.Fault,
match=(
r"client error; \('since',\): value is not a valid integer; "
r"\('with_ids',\): value could not be parsed to a boolean"
r"\('with_ids',\): value is not a valid boolean"
),
):
webtest.xmlrpc("/pypi", "changelog", "wrong!", "also wrong!")
28 changes: 16 additions & 12 deletions warehouse/legacy/api/xmlrpc/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from collections.abc import Mapping

from packaging.utils import canonicalize_name
from pydantic import ValidationError
from pydantic import StrictBool, StrictInt, StrictStr, ValidationError
from pydantic.decorator import ValidatedFunction
from pyramid.httpexceptions import HTTPTooManyRequests
from pyramid.view import view_config
Expand Down Expand Up @@ -230,7 +230,11 @@ def exception_view(exc, request):


@xmlrpc_method(method="search")
def search(request, spec: Mapping[str, str | list[str]], operator: str = "and"):
def search(
request,
spec: Mapping[StrictStr, StrictStr | list[StrictStr]],
operator: StrictStr = "and",
):
domain = request.registry.settings.get("warehouse.domain", request.domain)
raise XMLRPCWrappedError(
RuntimeError(
Expand All @@ -254,12 +258,12 @@ def list_packages_with_serial(request):


@xmlrpc_method(method="package_hosting_mode")
def package_hosting_mode(request, package_name: str):
def package_hosting_mode(request, package_name: StrictStr):
return "pypi-only"


@xmlrpc_method(method="user_packages")
def user_packages(request, username: str):
def user_packages(request, username: StrictStr):
roles = (
request.db.query(Role)
.join(User)
Expand All @@ -272,7 +276,7 @@ def user_packages(request, username: str):


@xmlrpc_method(method="top_packages")
def top_packages(request, num=None):
def top_packages(request, num: StrictInt | None = None):
raise XMLRPCWrappedError(
RuntimeError(
"This API has been removed. Use BigQuery instead. "
Expand All @@ -282,7 +286,7 @@ def top_packages(request, num=None):


@xmlrpc_cache_by_project(method="package_releases")
def package_releases(request, package_name: str, show_hidden: bool = False):
def package_releases(request, package_name: StrictStr, show_hidden: StrictBool = False):
try:
project = (
request.db.query(Project)
Expand Down Expand Up @@ -316,7 +320,7 @@ def package_data(request, package_name, version):


@xmlrpc_cache_by_project(method="release_data")
def release_data(request, package_name: str, version: str):
def release_data(request, package_name: StrictStr, version: StrictStr):
try:
release = (
request.db.query(Release)
Expand Down Expand Up @@ -386,7 +390,7 @@ def package_urls(request, package_name, version):


@xmlrpc_cache_by_project(method="release_urls")
def release_urls(request, package_name: str, version: str):
def release_urls(request, package_name: StrictStr, version: StrictStr):
files = (
request.db.query(File)
.join(Release)
Expand Down Expand Up @@ -424,7 +428,7 @@ def release_urls(request, package_name: str, version: str):


@xmlrpc_cache_by_project(method="package_roles")
def package_roles(request, package_name: str):
def package_roles(request, package_name: StrictStr):
roles = (
request.db.query(Role)
.join(User)
Expand All @@ -442,7 +446,7 @@ def changelog_last_serial(request):


@xmlrpc_method(method="changelog_since_serial")
def changelog_since_serial(request, serial: int):
def changelog_since_serial(request, serial: StrictInt):
entries = (
request.db.query(JournalEntry)
.filter(JournalEntry.id > serial)
Expand All @@ -463,7 +467,7 @@ def changelog_since_serial(request, serial: int):


@xmlrpc_method(method="changelog")
def changelog(request, since: int, with_ids: bool = False):
def changelog(request, since: StrictInt, with_ids: StrictBool = False):
since_dt = datetime.datetime.utcfromtimestamp(since)
entries = (
request.db.query(JournalEntry)
Expand All @@ -490,7 +494,7 @@ def changelog(request, since: int, with_ids: bool = False):


@xmlrpc_method(method="browse")
def browse(request, classifiers: list[str]):
def browse(request, classifiers: list[StrictStr]):
classifiers_q = (
request.db.query(Classifier)
.filter(Classifier.classifier.in_(classifiers))
Expand Down

0 comments on commit c94bbf6

Please sign in to comment.