Skip to content

Commit

Permalink
Add ruff for linting
Browse files Browse the repository at this point in the history
Apply ruff fixes
  • Loading branch information
yugokato committed Oct 1, 2024
1 parent 434df7e commit 8ac86e5
Show file tree
Hide file tree
Showing 16 changed files with 98 additions and 65 deletions.
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ ci:
autofix_prs: false
autoupdate_schedule: quarterly
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.6.8"
hooks:
- id: ruff
args: ["--fix"]
types_or: [python]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
hooks:
Expand Down
24 changes: 23 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,22 @@ dependencies = [
"inflect==7.0.0",
"phonenumbers==8.13.45",
"pydantic-extra-types==2.9.0",
"pre-commit==3.6.2",
"pydantic[email]==2.9.2",
"PyYAML==6.0.1",
]
readme = "README.md"

[project.optional-dependencies]
dev = ["pre-commit==3.6.2", "ruff==0.6.8"]
app = [
"common-libs[dev]",
"Quart==0.19.4",
"quart-auth==0.9.0",
"quart-schema[pydantic]==0.19.1",

]
test = [
"common-libs[dev]",
"openapi-test-client[app]",
"pytest==8.3.2",
"pytest-lazy-fixtures==1.1.1",
Expand All @@ -49,3 +51,23 @@ profile = "black"

[tool.black]
line_length = 120

[tool.ruff]
line-length = 120
indent-width = 4

[tool.ruff.lint]
select = [
# pycodestyle
"E",
# Pyflakes
"F",
# pyupgrade
"UP",
]
ignore = ["E501", "E731", "E741", "F403"]

[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401"]
# We currently use `Optional` in a special way
"**/{clients,}/**/{api,models}/*" = ["UP007"]
2 changes: 1 addition & 1 deletion src/demo_app/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def _register_blueprints(app, version: int):
from demo_app.handlers.error_handlers import bp_error_handler
from demo_app.handlers.request_handlers import bp_request_handler

bp_api = Blueprint(f"demo_app", __name__, url_prefix=f"/v{version}")
bp_api = Blueprint("demo_app", __name__, url_prefix=f"/v{version}")
bp_api.register_blueprint(bp_auth, name=bp_auth.name)
bp_api.register_blueprint(bp_user, name=bp_user.name)

Expand Down
29 changes: 14 additions & 15 deletions src/demo_app/api/user/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from enum import Enum
from typing import Optional

from pydantic import AnyUrl, BaseModel, EmailStr, Field
from quart_schema.pydantic import File
Expand All @@ -18,35 +17,35 @@ class UserTheme(Enum):


class UserQuery(BaseModel):
id: Optional[int] = None
email: Optional[EmailStr] = None
role: Optional[UserRole] = None
id: int | None = None
email: EmailStr | None = None
role: UserRole | None = None


class SocialLinks(BaseModel):
facebook: Optional[AnyUrl] = None
instagram: Optional[AnyUrl] = None
linkedin: Optional[AnyUrl] = None
github: Optional[AnyUrl] = None
facebook: AnyUrl | None = None
instagram: AnyUrl | None = None
linkedin: AnyUrl | None = None
github: AnyUrl | None = None


class Preferences(BaseModel):
theme: Optional[UserTheme] = UserTheme.LIGHT_MODE.value
language: Optional[str] = None
font_size: Optional[int] = Field(None, ge=8, le=40, multiple_of=2)
theme: UserTheme | None = UserTheme.LIGHT_MODE.value
language: str | None = None
font_size: int | None = Field(None, ge=8, le=40, multiple_of=2)


class Metadata(BaseModel):
preferences: Optional[Preferences] = None
social_links: Optional[SocialLinks] = None
preferences: Preferences | None = None
social_links: SocialLinks | None = None


class UserRequest(BaseModel):
first_name: str = Field(..., min_length=1, max_length=255)
last_name: str = Field(..., min_length=1, max_length=255)
email: EmailStr
role: UserRole
metadata: Optional[Metadata] = Field(default_factory=dict)
metadata: Metadata | None = Field(default_factory=dict)


class User(UserRequest):
Expand All @@ -55,4 +54,4 @@ class User(UserRequest):

class UserImage(BaseModel):
file: File
description: Optional[str] = None
description: str | None = None
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

from collections.abc import Callable
from functools import wraps
from typing import TYPE_CHECKING, Callable, ParamSpec
from typing import TYPE_CHECKING, ParamSpec

from common_libs.clients.rest_client import RestResponse

Expand Down
11 changes: 6 additions & 5 deletions src/openapi_test_client/libraries/api/api_classes/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

from abc import ABCMeta, abstractmethod
from typing import TYPE_CHECKING, Callable, Optional
from collections.abc import Callable
from typing import TYPE_CHECKING

from common_libs.clients.rest_client import RestResponse
from common_libs.logging import get_logger
Expand All @@ -18,10 +19,10 @@
class APIBase(metaclass=ABCMeta):
"""Base API class"""

app_name: Optional[str] = None
app_name: str | None = None
is_documented: bool = True
is_deprecated: bool = False
endpoints: Optional[list[Endpoint]] = None
endpoints: list[Endpoint] | None = None

def __init__(self, api_client: APIClientType):
if self.app_name != api_client.app_name:
Expand Down Expand Up @@ -51,8 +52,8 @@ def pre_request_hook(self, endpoint: Endpoint, *path_params, **params):
def post_request_hook(
self,
endpoint: Endpoint,
response: Optional[RestResponse],
request_exception: Optional[RequestException],
response: RestResponse | None,
request_exception: RequestException | None,
*path_params,
**params,
):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections.abc import Callable
from functools import wraps
from typing import Callable, ParamSpec
from typing import ParamSpec

from common_libs.clients.rest_client import RestResponse
from common_libs.logging import get_logger
Expand Down
15 changes: 8 additions & 7 deletions src/openapi_test_client/libraries/api/api_functions/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

from collections.abc import Callable, Sequence
from copy import deepcopy
from dataclasses import dataclass
from functools import partial, update_wrapper, wraps
from threading import RLock
from typing import TYPE_CHECKING, Any, Callable, Optional, ParamSpec, Sequence, TypeVar, cast
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast

from common_libs.ansi_colors import ColorCodes, color
from common_libs.clients.rest_client import RestResponse
Expand Down Expand Up @@ -49,8 +50,8 @@ class Endpoint:
path: str
func_name: str
model: type[EndpointModel]
url: Optional[str] = None # Available only for an endpoint object accessed via an API client instance
content_type: Optional[str] = None
url: str | None = None # Available only for an endpoint object accessed via an API client instance
content_type: str | None = None
is_public: bool = False
is_documented: bool = True
is_deprecated: bool = False
Expand Down Expand Up @@ -409,7 +410,7 @@ def __init__(
self.is_deprecated = False
self.__decorators = []

def __get__(self, instance: Optional[APIClassType], owner: type[APIClassType]) -> EndpointFunc:
def __get__(self, instance: APIClassType | None, owner: type[APIClassType]) -> EndpointFunc:
"""Return an EndpointFunc object"""
key = (self.original_func.__name__, instance, owner)
with EndpointHandler._lock:
Expand Down Expand Up @@ -442,14 +443,14 @@ class EndpointFunc:
All parameters passed to the original API class function call will be passed through to the __call__()
"""

def __init__(self, endpoint_handler: EndpointHandler, instance: Optional[APIClassType], owner: type[APIClassType]):
def __init__(self, endpoint_handler: EndpointHandler, instance: APIClassType | None, owner: type[APIClassType]):
"""Initialize endpoint function"""
if not issubclass(owner, APIBase):
raise NotImplementedError(f"Unsupported API class: {owner}")

self.method = endpoint_handler.method
self.path = endpoint_handler.path
self.rest_client: Optional[RestClient]
self.rest_client: RestClient | None
if instance:
self.api_client = instance.api_client
self.rest_client = self.api_client.rest_client
Expand Down Expand Up @@ -643,7 +644,7 @@ def with_retry(
f = retry_on(condition, num_retry=num_retry, retry_after=retry_after, safe_methods_only=False)(self)
return f(*args, **kwargs)

def get_usage(self) -> Optional[str]:
def get_usage(self) -> str | None:
"""Get OpenAPI spec definition for the endpoint"""
if self.api_client and self.endpoint.is_documented:
return self.api_client.api_spec.get_endpoint_usage(self.endpoint)
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import re
from collections import OrderedDict
from typing import TYPE_CHECKING, Annotated, Any, Optional, get_args, get_origin
from typing import TYPE_CHECKING, Annotated, Any, get_args, get_origin

from common_libs.clients.rest_client.utils import get_supported_request_parameters
from common_libs.logging import get_logger
Expand Down Expand Up @@ -224,7 +224,7 @@ def generate_rest_func_params(
# We will set the Content-type value using from the OpenAPI specs for this case, unless the header is explicitly
# set by a user. Otherwise, requests lib will automatically handle this part
if (data := rest_func_params.get("data")) and (
isinstance(data, (str, bytes)) and not specified_content_type_header and endpoint.content_type
isinstance(data, str | bytes) and not specified_content_type_header and endpoint.content_type
):
rest_func_params.setdefault("headers", {}).update({"Content-Type": endpoint.content_type})

Expand All @@ -233,7 +233,7 @@ def generate_rest_func_params(

def _get_specified_content_type_header(
requests_lib_options: dict[str, Any], session_headers: dict[str, str]
) -> Optional[str]:
) -> str | None:
"""Get Content-Type header value set for the request or for the current session"""
request_headers = requests_lib_options.get("headers", {})
content_type_header = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import re
from copy import deepcopy
from dataclasses import MISSING, Field, field, make_dataclass
from typing import TYPE_CHECKING, Any, Optional, cast
from typing import TYPE_CHECKING, Any, cast

from common_libs.logging import get_logger

Expand Down Expand Up @@ -131,7 +131,7 @@ def _parse_parameter_objects(
method: str,
parameter_objects: list[dict[str, Any]],
path_param_fields: list[tuple[str, Any]],
body_or_query_param_fields: list[tuple[str, Any, Optional[Field]]],
body_or_query_param_fields: list[tuple[str, Any, Field | None]],
):
"""Parse parameter objects
Expand Down Expand Up @@ -205,8 +205,8 @@ def _parse_parameter_objects(


def _parse_request_body_object(
request_body_obj: dict[str, Any], body_or_query_param_fields: list[tuple[str, Any, Optional[Field]]]
) -> Optional[str]:
request_body_obj: dict[str, Any], body_or_query_param_fields: list[tuple[str, Any, Field | None]]
) -> str | None:
"""Parse request body object
https://swagger.io/specification/#request-body-object
Expand Down Expand Up @@ -249,7 +249,7 @@ def parse_schema_obj(obj: dict[str, Any]):
if _is_file_param(content_type, param_def):
param_type = File
if not param_def.is_required:
param_type = Optional[param_type]
param_type = param_type | None
body_or_query_param_fields.append((param_name, param_type, field(default=None)))
else:
existing_param_names = [x[0] for x in body_or_query_param_fields]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def _is_param_model(obj: Any) -> bool:
return _is_param_model(inner_type)


def get_param_model(annotated_type: Any) -> Optional[ParamModel | list[ParamModel]]:
def get_param_model(annotated_type: Any) -> ParamModel | list[ParamModel] | None:
"""Returns a param model from the annotated type, if there is any
:param annotated_type: Annotated type
Expand Down Expand Up @@ -91,7 +91,7 @@ def get_reserved_model_names() -> list[str]:
custom_param_annotation_names = [
x.__name__
for x in mod.__dict__.values()
if inspect.isclass(x) and issubclass(x, (ParamAnnotationType, DataclassModel))
if inspect.isclass(x) and issubclass(x, ParamAnnotationType | DataclassModel)
]
typing_class_names = [x.__name__ for x in [Any, Optional, Annotated, Literal, Union]]
return custom_param_annotation_names + typing_class_names
Expand All @@ -106,7 +106,7 @@ def create_model_from_param_def(
:param model_name: The model name
:param param_def: ParamDef generated from an OpenAPI parameter object
"""
if not isinstance(param_def, (ParamDef, ParamDef.ParamGroup, ParamDef.UnknownType)):
if not isinstance(param_def, ParamDef | ParamDef.ParamGroup | ParamDef.UnknownType):
raise ValueError(f"Invalid param_def type: {type(param_def)}")

if isinstance(param_def, ParamDef) and param_def.is_array and "items" in param_def:
Expand Down Expand Up @@ -273,7 +273,7 @@ def visit(model_name: str):
return sorted(models, key=lambda x: sorted_models_names.index(x.__name__))


def alias_illegal_model_field_names(param_fields: list[tuple[str, Any] | tuple[str, Any, Optional[Field]]]):
def alias_illegal_model_field_names(param_fields: list[tuple[str, Any] | tuple[str, Any, Field | None]]):
"""Clean illegal model field name and annotate the field type with Alias class
:param param_fields: fields value to be passed to make_dataclass()
Expand Down
Loading

0 comments on commit 8ac86e5

Please sign in to comment.