diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f88ec01..80c3a3c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,6 +4,12 @@ ci: autofix_prs: false autoupdate_schedule: quarterly repos: +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: "v0.6.8" + hooks: + - id: ruff + args: ["--fix"] + types_or: [python] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.6.0 hooks: diff --git a/pyproject.toml b/pyproject.toml index f847c0c..545676b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,20 +10,22 @@ dependencies = [ "inflect==7.0.0", "phonenumbers==8.13.45", "pydantic-extra-types==2.9.0", - "pre-commit==3.6.2", "pydantic[email]==2.9.2", "PyYAML==6.0.1", ] readme = "README.md" [project.optional-dependencies] +dev = ["pre-commit==3.6.2", "ruff==0.6.8"] app = [ + "common-libs[dev]", "Quart==0.19.4", "quart-auth==0.9.0", "quart-schema[pydantic]==0.19.1", ] test = [ + "common-libs[dev]", "openapi-test-client[app]", "pytest==8.3.2", "pytest-lazy-fixtures==1.1.1", @@ -49,3 +51,23 @@ profile = "black" [tool.black] line_length = 120 + +[tool.ruff] +line-length = 120 +indent-width = 4 + +[tool.ruff.lint] +select = [ + # pycodestyle + "E", + # Pyflakes + "F", + # pyupgrade + "UP", +] +ignore = ["E501", "E731", "E741", "F403"] + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["F401"] +# We currently use `Optional` in a special way +"**/{clients,}/**/{api,models}/*" = ["UP007"] diff --git a/src/demo_app/__init__.py b/src/demo_app/__init__.py index 8c15ebd..cb13918 100644 --- a/src/demo_app/__init__.py +++ b/src/demo_app/__init__.py @@ -33,7 +33,7 @@ def _register_blueprints(app, version: int): from demo_app.handlers.error_handlers import bp_error_handler from demo_app.handlers.request_handlers import bp_request_handler - bp_api = Blueprint(f"demo_app", __name__, url_prefix=f"/v{version}") + bp_api = Blueprint("demo_app", __name__, url_prefix=f"/v{version}") bp_api.register_blueprint(bp_auth, name=bp_auth.name) bp_api.register_blueprint(bp_user, name=bp_user.name) diff --git a/src/demo_app/api/user/models.py b/src/demo_app/api/user/models.py index 5e2ceb7..cb5cd2e 100644 --- a/src/demo_app/api/user/models.py +++ b/src/demo_app/api/user/models.py @@ -1,5 +1,4 @@ from enum import Enum -from typing import Optional from pydantic import AnyUrl, BaseModel, EmailStr, Field from quart_schema.pydantic import File @@ -18,27 +17,27 @@ class UserTheme(Enum): class UserQuery(BaseModel): - id: Optional[int] = None - email: Optional[EmailStr] = None - role: Optional[UserRole] = None + id: int | None = None + email: EmailStr | None = None + role: UserRole | None = None class SocialLinks(BaseModel): - facebook: Optional[AnyUrl] = None - instagram: Optional[AnyUrl] = None - linkedin: Optional[AnyUrl] = None - github: Optional[AnyUrl] = None + facebook: AnyUrl | None = None + instagram: AnyUrl | None = None + linkedin: AnyUrl | None = None + github: AnyUrl | None = None class Preferences(BaseModel): - theme: Optional[UserTheme] = UserTheme.LIGHT_MODE.value - language: Optional[str] = None - font_size: Optional[int] = Field(None, ge=8, le=40, multiple_of=2) + theme: UserTheme | None = UserTheme.LIGHT_MODE.value + language: str | None = None + font_size: int | None = Field(None, ge=8, le=40, multiple_of=2) class Metadata(BaseModel): - preferences: Optional[Preferences] = None - social_links: Optional[SocialLinks] = None + preferences: Preferences | None = None + social_links: SocialLinks | None = None class UserRequest(BaseModel): @@ -46,7 +45,7 @@ class UserRequest(BaseModel): last_name: str = Field(..., min_length=1, max_length=255) email: EmailStr role: UserRole - metadata: Optional[Metadata] = Field(default_factory=dict) + metadata: Metadata | None = Field(default_factory=dict) class User(UserRequest): @@ -55,4 +54,4 @@ class User(UserRequest): class UserImage(BaseModel): file: File - description: Optional[str] = None + description: str | None = None diff --git a/src/openapi_test_client/clients/demo_app/api/request_hooks/request_wrapper.py b/src/openapi_test_client/clients/demo_app/api/request_hooks/request_wrapper.py index 4639763..fa0f461 100644 --- a/src/openapi_test_client/clients/demo_app/api/request_hooks/request_wrapper.py +++ b/src/openapi_test_client/clients/demo_app/api/request_hooks/request_wrapper.py @@ -1,7 +1,8 @@ from __future__ import annotations +from collections.abc import Callable from functools import wraps -from typing import TYPE_CHECKING, Callable, ParamSpec +from typing import TYPE_CHECKING, ParamSpec from common_libs.clients.rest_client import RestResponse diff --git a/src/openapi_test_client/libraries/api/api_classes/base.py b/src/openapi_test_client/libraries/api/api_classes/base.py index 6db1522..f5e1dc1 100644 --- a/src/openapi_test_client/libraries/api/api_classes/base.py +++ b/src/openapi_test_client/libraries/api/api_classes/base.py @@ -1,7 +1,8 @@ from __future__ import annotations from abc import ABCMeta, abstractmethod -from typing import TYPE_CHECKING, Callable, Optional +from collections.abc import Callable +from typing import TYPE_CHECKING from common_libs.clients.rest_client import RestResponse from common_libs.logging import get_logger @@ -18,10 +19,10 @@ class APIBase(metaclass=ABCMeta): """Base API class""" - app_name: Optional[str] = None + app_name: str | None = None is_documented: bool = True is_deprecated: bool = False - endpoints: Optional[list[Endpoint]] = None + endpoints: list[Endpoint] | None = None def __init__(self, api_client: APIClientType): if self.app_name != api_client.app_name: @@ -51,8 +52,8 @@ def pre_request_hook(self, endpoint: Endpoint, *path_params, **params): def post_request_hook( self, endpoint: Endpoint, - response: Optional[RestResponse], - request_exception: Optional[RequestException], + response: RestResponse | None, + request_exception: RequestException | None, *path_params, **params, ): diff --git a/src/openapi_test_client/libraries/api/api_functions/decorators.py b/src/openapi_test_client/libraries/api/api_functions/decorators.py index 4341814..5824f17 100644 --- a/src/openapi_test_client/libraries/api/api_functions/decorators.py +++ b/src/openapi_test_client/libraries/api/api_functions/decorators.py @@ -1,5 +1,6 @@ +from collections.abc import Callable from functools import wraps -from typing import Callable, ParamSpec +from typing import ParamSpec from common_libs.clients.rest_client import RestResponse from common_libs.logging import get_logger diff --git a/src/openapi_test_client/libraries/api/api_functions/endpoints.py b/src/openapi_test_client/libraries/api/api_functions/endpoints.py index 65a5646..eb527db 100644 --- a/src/openapi_test_client/libraries/api/api_functions/endpoints.py +++ b/src/openapi_test_client/libraries/api/api_functions/endpoints.py @@ -1,10 +1,11 @@ from __future__ import annotations +from collections.abc import Callable, Sequence from copy import deepcopy from dataclasses import dataclass from functools import partial, update_wrapper, wraps from threading import RLock -from typing import TYPE_CHECKING, Any, Callable, Optional, ParamSpec, Sequence, TypeVar, cast +from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast from common_libs.ansi_colors import ColorCodes, color from common_libs.clients.rest_client import RestResponse @@ -49,8 +50,8 @@ class Endpoint: path: str func_name: str model: type[EndpointModel] - url: Optional[str] = None # Available only for an endpoint object accessed via an API client instance - content_type: Optional[str] = None + url: str | None = None # Available only for an endpoint object accessed via an API client instance + content_type: str | None = None is_public: bool = False is_documented: bool = True is_deprecated: bool = False @@ -409,7 +410,7 @@ def __init__( self.is_deprecated = False self.__decorators = [] - def __get__(self, instance: Optional[APIClassType], owner: type[APIClassType]) -> EndpointFunc: + def __get__(self, instance: APIClassType | None, owner: type[APIClassType]) -> EndpointFunc: """Return an EndpointFunc object""" key = (self.original_func.__name__, instance, owner) with EndpointHandler._lock: @@ -442,14 +443,14 @@ class EndpointFunc: All parameters passed to the original API class function call will be passed through to the __call__() """ - def __init__(self, endpoint_handler: EndpointHandler, instance: Optional[APIClassType], owner: type[APIClassType]): + def __init__(self, endpoint_handler: EndpointHandler, instance: APIClassType | None, owner: type[APIClassType]): """Initialize endpoint function""" if not issubclass(owner, APIBase): raise NotImplementedError(f"Unsupported API class: {owner}") self.method = endpoint_handler.method self.path = endpoint_handler.path - self.rest_client: Optional[RestClient] + self.rest_client: RestClient | None if instance: self.api_client = instance.api_client self.rest_client = self.api_client.rest_client @@ -643,7 +644,7 @@ def with_retry( f = retry_on(condition, num_retry=num_retry, retry_after=retry_after, safe_methods_only=False)(self) return f(*args, **kwargs) - def get_usage(self) -> Optional[str]: + def get_usage(self) -> str | None: """Get OpenAPI spec definition for the endpoint""" if self.api_client and self.endpoint.is_documented: return self.api_client.api_spec.get_endpoint_usage(self.endpoint) diff --git a/src/openapi_test_client/libraries/api/api_functions/utils/endpoint_function.py b/src/openapi_test_client/libraries/api/api_functions/utils/endpoint_function.py index 6281ff1..cee010c 100644 --- a/src/openapi_test_client/libraries/api/api_functions/utils/endpoint_function.py +++ b/src/openapi_test_client/libraries/api/api_functions/utils/endpoint_function.py @@ -3,7 +3,7 @@ import json import re from collections import OrderedDict -from typing import TYPE_CHECKING, Annotated, Any, Optional, get_args, get_origin +from typing import TYPE_CHECKING, Annotated, Any, get_args, get_origin from common_libs.clients.rest_client.utils import get_supported_request_parameters from common_libs.logging import get_logger @@ -224,7 +224,7 @@ def generate_rest_func_params( # We will set the Content-type value using from the OpenAPI specs for this case, unless the header is explicitly # set by a user. Otherwise, requests lib will automatically handle this part if (data := rest_func_params.get("data")) and ( - isinstance(data, (str, bytes)) and not specified_content_type_header and endpoint.content_type + isinstance(data, str | bytes) and not specified_content_type_header and endpoint.content_type ): rest_func_params.setdefault("headers", {}).update({"Content-Type": endpoint.content_type}) @@ -233,7 +233,7 @@ def generate_rest_func_params( def _get_specified_content_type_header( requests_lib_options: dict[str, Any], session_headers: dict[str, str] -) -> Optional[str]: +) -> str | None: """Get Content-Type header value set for the request or for the current session""" request_headers = requests_lib_options.get("headers", {}) content_type_header = ( diff --git a/src/openapi_test_client/libraries/api/api_functions/utils/endpoint_model.py b/src/openapi_test_client/libraries/api/api_functions/utils/endpoint_model.py index 79c5a01..3ce65ab 100644 --- a/src/openapi_test_client/libraries/api/api_functions/utils/endpoint_model.py +++ b/src/openapi_test_client/libraries/api/api_functions/utils/endpoint_model.py @@ -5,7 +5,7 @@ import re from copy import deepcopy from dataclasses import MISSING, Field, field, make_dataclass -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, Any, cast from common_libs.logging import get_logger @@ -131,7 +131,7 @@ def _parse_parameter_objects( method: str, parameter_objects: list[dict[str, Any]], path_param_fields: list[tuple[str, Any]], - body_or_query_param_fields: list[tuple[str, Any, Optional[Field]]], + body_or_query_param_fields: list[tuple[str, Any, Field | None]], ): """Parse parameter objects @@ -205,8 +205,8 @@ def _parse_parameter_objects( def _parse_request_body_object( - request_body_obj: dict[str, Any], body_or_query_param_fields: list[tuple[str, Any, Optional[Field]]] -) -> Optional[str]: + request_body_obj: dict[str, Any], body_or_query_param_fields: list[tuple[str, Any, Field | None]] +) -> str | None: """Parse request body object https://swagger.io/specification/#request-body-object @@ -249,7 +249,7 @@ def parse_schema_obj(obj: dict[str, Any]): if _is_file_param(content_type, param_def): param_type = File if not param_def.is_required: - param_type = Optional[param_type] + param_type = param_type | None body_or_query_param_fields.append((param_name, param_type, field(default=None))) else: existing_param_names = [x[0] for x in body_or_query_param_fields] diff --git a/src/openapi_test_client/libraries/api/api_functions/utils/param_model.py b/src/openapi_test_client/libraries/api/api_functions/utils/param_model.py index fba8eb6..c61d53c 100644 --- a/src/openapi_test_client/libraries/api/api_functions/utils/param_model.py +++ b/src/openapi_test_client/libraries/api/api_functions/utils/param_model.py @@ -48,7 +48,7 @@ def _is_param_model(obj: Any) -> bool: return _is_param_model(inner_type) -def get_param_model(annotated_type: Any) -> Optional[ParamModel | list[ParamModel]]: +def get_param_model(annotated_type: Any) -> ParamModel | list[ParamModel] | None: """Returns a param model from the annotated type, if there is any :param annotated_type: Annotated type @@ -91,7 +91,7 @@ def get_reserved_model_names() -> list[str]: custom_param_annotation_names = [ x.__name__ for x in mod.__dict__.values() - if inspect.isclass(x) and issubclass(x, (ParamAnnotationType, DataclassModel)) + if inspect.isclass(x) and issubclass(x, ParamAnnotationType | DataclassModel) ] typing_class_names = [x.__name__ for x in [Any, Optional, Annotated, Literal, Union]] return custom_param_annotation_names + typing_class_names @@ -106,7 +106,7 @@ def create_model_from_param_def( :param model_name: The model name :param param_def: ParamDef generated from an OpenAPI parameter object """ - if not isinstance(param_def, (ParamDef, ParamDef.ParamGroup, ParamDef.UnknownType)): + if not isinstance(param_def, ParamDef | ParamDef.ParamGroup | ParamDef.UnknownType): raise ValueError(f"Invalid param_def type: {type(param_def)}") if isinstance(param_def, ParamDef) and param_def.is_array and "items" in param_def: @@ -273,7 +273,7 @@ def visit(model_name: str): return sorted(models, key=lambda x: sorted_models_names.index(x.__name__)) -def alias_illegal_model_field_names(param_fields: list[tuple[str, Any] | tuple[str, Any, Optional[Field]]]): +def alias_illegal_model_field_names(param_fields: list[tuple[str, Any] | tuple[str, Any, Field | None]]): """Clean illegal model field name and annotate the field type with Alias class :param param_fields: fields value to be passed to make_dataclass() diff --git a/src/openapi_test_client/libraries/api/api_functions/utils/param_type.py b/src/openapi_test_client/libraries/api/api_functions/utils/param_type.py index 5bdd13c..38ed5cc 100644 --- a/src/openapi_test_client/libraries/api/api_functions/utils/param_type.py +++ b/src/openapi_test_client/libraries/api/api_functions/utils/param_type.py @@ -1,10 +1,11 @@ import inspect +from collections.abc import Sequence from dataclasses import asdict from functools import reduce from operator import or_ from types import NoneType, UnionType from typing import _AnnotatedAlias # noqa -from typing import Annotated, Any, Literal, Optional, Sequence, Union, get_args, get_origin +from typing import Annotated, Any, Literal, Optional, Union, get_args, get_origin from common_libs.logging import get_logger @@ -50,7 +51,7 @@ def get_type_annotation_as_str(tp: Any) -> str: return f"{tp.__origin__.__name__}[{inner_types}]" elif get_origin(tp) is Literal: return repr(tp).replace("typing.", "") - elif isinstance(tp, (Alias, Format)): + elif isinstance(tp, Alias | Format): return f"{type(tp).__name__}({repr(tp.value)})" elif isinstance(tp, Constraint): const = ", ".join( @@ -71,7 +72,7 @@ def get_type_annotation_as_str(tp: Any) -> str: def resolve_type_annotation( param_name: str, param_def: ParamDef | ParamDef.ParamGroup | ParamDef.UnknownType, - _is_required: Optional[bool] = None, + _is_required: bool | None = None, _is_array: bool = False, ) -> Any: """Resolve type annotation for the given parameter definition @@ -134,7 +135,7 @@ def resolve(param_type: str, param_format: str = None): else: raise NotImplementedError(f"Unsupported type: {param_type}") - if not isinstance(param_def, (ParamDef, ParamDef.ParamGroup, ParamDef.UnknownType)): + if not isinstance(param_def, ParamDef | ParamDef.ParamGroup | ParamDef.UnknownType): # for inner obj param_def = ParamDef.from_param_obj(param_def) @@ -236,7 +237,7 @@ def replace_inner_type(tp: Any, new_type: Any, replace_container_type: bool = Fa args = get_args(tp) if is_union_type(tp): if is_optional_type(tp): - return Optional[replace_inner_type(args[0], new_type)] + return Optional[replace_inner_type(args[0], new_type)] # noqa: UP007 else: return replace_inner_type(args, new_type) elif origin_type is Annotated: @@ -355,7 +356,7 @@ def generate_optional_type(tp: Any) -> Any: if is_optional_type(tp): return tp else: - return Union[tp, None] + return Union[tp, None] # noqa: UP007 def generate_annotated_type(tp: Any, metadata: Any): @@ -365,12 +366,12 @@ def generate_annotated_type(tp: Any, metadata: Any): """ if is_optional_type(tp): inner_type = get_args(tp)[0] - return Optional[Annotated[inner_type, metadata]] + return Optional[Annotated[inner_type, metadata]] # noqa: UP007 else: return Annotated[tp, metadata] -def get_annotated_type(tp: Any) -> Optional[_AnnotatedAlias]: +def get_annotated_type(tp: Any) -> _AnnotatedAlias | None: """Get annotated type definition :param tp: Type annotation diff --git a/src/openapi_test_client/libraries/api/api_functions/utils/pydantic_model.py b/src/openapi_test_client/libraries/api/api_functions/utils/pydantic_model.py index fb3df6c..a4307b0 100644 --- a/src/openapi_test_client/libraries/api/api_functions/utils/pydantic_model.py +++ b/src/openapi_test_client/libraries/api/api_functions/utils/pydantic_model.py @@ -4,7 +4,7 @@ from datetime import date, datetime, time, timedelta from pathlib import Path from types import EllipsisType -from typing import Any, Optional, TypeVar, get_origin +from typing import Any, TypeVar, get_origin from uuid import UUID from pydantic import ( @@ -72,7 +72,7 @@ def in_validation_mode(): def generate_pydantic_model_fields( original_model: type[DataclassModel | EndpointModel | ParamModel], field_type: Any -) -> tuple[str, Optional[EllipsisType | FieldInfo]]: +) -> tuple[str, EllipsisType | FieldInfo | None]: """Generate Pydantic field definition for validation mode :param original_model: The original model @@ -134,7 +134,7 @@ def generate_pydantic_model_fields( if default_value is not None and constraint.nullable: # Required and nullable = Optional - field_type = Optional[field_type] + field_type = field_type | None # For query parameters,each parameter may be allowed to use multiple times with different values. Our client will # support this scenario by taking values as a list. To prevent a validation error to occur when giving a list, @@ -143,13 +143,13 @@ def generate_pydantic_model_fields( issubclass(original_model, EndpointModel) and original_model.endpoint_func.method.upper() == "GET" ): inner_type = param_type_util.get_inner_type(field_type) - if not get_origin(inner_type) is list: + if get_origin(inner_type) is not list: field_type = param_type_util.replace_inner_type(field_type, inner_type | list[inner_type]) return (field_type, field_value) -def filter_annotated_metadata(annotated_type: Any, target_class: type[T]) -> Optional[T]: +def filter_annotated_metadata(annotated_type: Any, target_class: type[T]) -> T | None: """Get a metadata for the target class from annotated type :param annotated_type: Type annotation with Annotated[] diff --git a/src/openapi_test_client/libraries/api/api_spec.py b/src/openapi_test_client/libraries/api/api_spec.py index 3fdd74a..137e449 100644 --- a/src/openapi_test_client/libraries/api/api_spec.py +++ b/src/openapi_test_client/libraries/api/api_spec.py @@ -4,7 +4,7 @@ import json import re from functools import lru_cache, reduce -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any import requests import yaml @@ -29,7 +29,7 @@ def __init__(self, api_client: APIClientType, doc_path: str): self._spec = None @lru_cache - def get_api_spec(self, url: str = None) -> Optional[dict[str, Any]]: + def get_api_spec(self, url: str = None) -> dict[str, Any] | None: """Return OpenAPI spec""" if self._spec is None: if url: @@ -64,7 +64,7 @@ def get_api_spec(self, url: str = None) -> Optional[dict[str, Any]]: else: logger.warning("API spec is not available") - def get_endpoint_usage(self, endpoint: Endpoint) -> Optional[str]: + def get_endpoint_usage(self, endpoint: Endpoint) -> str | None: """Return usage of the endpoint :param endpoint: Endpoint object diff --git a/src/openapi_test_client/libraries/api/types.py b/src/openapi_test_client/libraries/api/types.py index a95c5ba..746f898 100644 --- a/src/openapi_test_client/libraries/api/types.py +++ b/src/openapi_test_client/libraries/api/types.py @@ -1,10 +1,11 @@ from __future__ import annotations import json +from collections.abc import Callable, Mapping, Sequence from dataclasses import _DataclassParams # noqa from dataclasses import MISSING, Field, asdict, astuple, dataclass, field, is_dataclass, make_dataclass from functools import lru_cache -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Mapping, Optional, Sequence, cast +from typing import TYPE_CHECKING, Any, ClassVar, cast from common_libs.decorators import freeze_args from common_libs.hash import HashableDict @@ -52,7 +53,7 @@ def type(self) -> str: return self["type"] @property - def format(self) -> Optional[str]: + def format(self) -> str | None: return self.get("format") @property @@ -104,7 +105,7 @@ def from_param_obj( """Convert the parameter object to a ParamDef""" def convert(obj: Any): - if isinstance(obj, (ParamDef, ParamDef.ParamGroup)): + if isinstance(obj, ParamDef | ParamDef.ParamGroup): return obj else: if "oneOf" in obj: @@ -192,7 +193,7 @@ def to_pydantic(cls) -> type[PydanticModel]: class EndpointModel(DataclassModel): - content_type: Optional[str] + content_type: str | None endpoint_func: EndpointFunc @@ -364,7 +365,7 @@ def setdefault(self, key: str, default: Any = None) -> Any: @classmethod def recreate( - cls, current_class: type[ParamModel], new_fields: list[tuple[str, Any, Optional[field]]] + cls, current_class: type[ParamModel], new_fields: list[tuple[str, Any, field | None]] ) -> type[ParamModel]: """Recreate the model with the new fields diff --git a/src/openapi_test_client/libraries/common/json_encoder.py b/src/openapi_test_client/libraries/common/json_encoder.py index f46eeea..cd788c3 100644 --- a/src/openapi_test_client/libraries/common/json_encoder.py +++ b/src/openapi_test_client/libraries/common/json_encoder.py @@ -9,7 +9,7 @@ class CustomJsonEncoder(json.JSONEncoder): def default(self, obj): - if isinstance(obj, (UUID, Decimal)): + if isinstance(obj, UUID | Decimal): return str(obj) elif isinstance(obj, datetime): return obj.isoformat()