diff --git a/integration-tests/tests/conftest.py b/integration-tests/tests/conftest.py index 9cdd679fa..8530594e3 100644 --- a/integration-tests/tests/conftest.py +++ b/integration-tests/tests/conftest.py @@ -5,4 +5,4 @@ @pytest.fixture(scope="session") def client() -> Client: - return Client("http://localhost:3000") + return Client(base_url="http://localhost:3000") diff --git a/openapi_python_client/__init__.py b/openapi_python_client/__init__.py index e7bd51475..cfc4b715c 100644 --- a/openapi_python_client/__init__.py +++ b/openapi_python_client/__init__.py @@ -253,6 +253,16 @@ def _build_models(self) -> None: models_init_template = self.env.get_template("models_init.py.jinja") models_init.write_text(models_init_template.render(imports=imports, alls=alls), encoding=self.file_encoding) + # Generate wrapper + wrapper = self.package_dir / "wrapper.py" + wrapper_template = self.env.get_template("wrapper.py.jinja") + wrapper.write_text( + wrapper_template.render( + imports=imports, + alls=alls, + ) + ) + # pylint: disable=too-many-locals def _build_api(self) -> None: # Generate Client diff --git a/openapi_python_client/templates/client.py.jinja b/openapi_python_client/templates/client.py.jinja index c6e6b2305..77d958b95 100644 --- a/openapi_python_client/templates/client.py.jinja +++ b/openapi_python_client/templates/client.py.jinja @@ -1,5 +1,7 @@ import ssl -from typing import Dict, Union +import os +from typing import Dict, Union, Optional + import attr @attr.s(auto_attribs=True) @@ -19,7 +21,7 @@ class Client: follow_redirects: Whether or not to follow redirects. Default value is False. """ - base_url: str + base_url: Optional[str] = attr.ib(None, kw_only=True) cookies: Dict[str, str] = attr.ib(factory=dict, kw_only=True) headers: Dict[str, str] = attr.ib(factory=dict, kw_only=True) timeout: float = attr.ib(5.0, kw_only=True) @@ -27,6 +29,14 @@ class Client: raise_on_unexpected_status: bool = attr.ib(False, kw_only=True) follow_redirects: bool = attr.ib(False, kw_only=True) + def __attrs_post_init__(self) -> None: + env_base_url = os.environ.get('{{ openapi.title | snakecase | upper }}_BASE_URL') + self.base_url = self.base_url or env_base_url + if self.base_url is None: + raise ValueError(f'"base_url" has to be set either from the ' + f'environment variable "{env_base_url}", or ' + f'passed with the "base_url" argument') + def get_headers(self) -> Dict[str, str]: """ Get headers to be used in all endpoints """ return {**self.headers} diff --git a/openapi_python_client/templates/endpoint_init.py.jinja b/openapi_python_client/templates/endpoint_init.py.jinja index e69de29bb..234f7abdf 100644 --- a/openapi_python_client/templates/endpoint_init.py.jinja +++ b/openapi_python_client/templates/endpoint_init.py.jinja @@ -0,0 +1 @@ +from . import {% for e in endpoint_collection.endpoints %} {{e.name | snakecase }}, {% endfor %} diff --git a/openapi_python_client/templates/endpoint_macros.py.jinja b/openapi_python_client/templates/endpoint_macros.py.jinja index 090796537..b5f44220a 100644 --- a/openapi_python_client/templates/endpoint_macros.py.jinja +++ b/openapi_python_client/templates/endpoint_macros.py.jinja @@ -78,18 +78,22 @@ params = {k: v for k, v in params.items() if v is not UNSET and v is not None} {% endmacro %} {# The all the kwargs passed into an endpoint (and variants thereof)) #} -{% macro arguments(endpoint) %} +{% macro arguments(endpoint, client=True) %} {# path parameters #} {% for parameter in endpoint.path_parameters.values() %} {{ parameter.to_string() }}, {% endfor %} +{% if endpoint.form_body_reference or endpoint.multipart_body_reference or endpoint.query_parameters or endpoint.json_body or endpoint.header_parameters or endpoint.cookie_parameters %} *, +{% endif %} {# Proper client based on whether or not the endpoint requires authentication #} +{% if client %} {% if endpoint.requires_security %} client: AuthenticatedClient, {% else %} client: Client, {% endif %} +{% endif %} {# Form data if any #} {% if endpoint.form_body %} form_data: {{ endpoint.form_body.get_type_string() }}, diff --git a/openapi_python_client/templates/package_init.py.jinja b/openapi_python_client/templates/package_init.py.jinja index ecf60e74d..1e5ca0e7e 100644 --- a/openapi_python_client/templates/package_init.py.jinja +++ b/openapi_python_client/templates/package_init.py.jinja @@ -1,9 +1,9 @@ {% from "helpers.jinja" import safe_docstring %} {{ safe_docstring(package_description) }} -from .client import AuthenticatedClient, Client +from .wrapper import Sync{{ openapi.title | pascalcase }}Client, {{ openapi.title | pascalcase }}Client __all__ = ( - "AuthenticatedClient", - "Client", + "Sync{{ openapi.title | pascalcase }}Client", + "{{ openapi.title | pascalcase }}Client" ) diff --git a/openapi_python_client/templates/pyproject.toml.jinja b/openapi_python_client/templates/pyproject.toml.jinja index e3ed7b57e..c4e803926 100644 --- a/openapi_python_client/templates/pyproject.toml.jinja +++ b/openapi_python_client/templates/pyproject.toml.jinja @@ -14,7 +14,7 @@ include = ["CHANGELOG.md", "{{ package_name }}/py.typed"] [tool.poetry.dependencies] python = "^3.8" -httpx = ">=0.15.4,<0.25.0" +httpx = ">=0.15.4" attrs = ">=21.3.0" python-dateutil = "^2.8.0" diff --git a/openapi_python_client/templates/setup.py.jinja b/openapi_python_client/templates/setup.py.jinja index c2bc949d4..288596466 100644 --- a/openapi_python_client/templates/setup.py.jinja +++ b/openapi_python_client/templates/setup.py.jinja @@ -13,6 +13,6 @@ setup( long_description_content_type="text/markdown", packages=find_packages(), python_requires=">=3.8, <4", - install_requires=["httpx >= 0.15.0, < 0.25.0", "attrs >= 21.3.0", "python-dateutil >= 2.8.0, < 3"], + install_requires=["httpx >= 0.15.0", "attrs >= 21.3.0", "python-dateutil >= 2.8.0, < 3"], package_data={"{{ package_name }}": ["py.typed"]}, ) diff --git a/openapi_python_client/templates/wrapper.py.jinja b/openapi_python_client/templates/wrapper.py.jinja new file mode 100644 index 000000000..775566ef0 --- /dev/null +++ b/openapi_python_client/templates/wrapper.py.jinja @@ -0,0 +1,141 @@ +import datetime +from dateutil.parser import isoparse +from typing import Any, Dict, Optional, Union, cast, List +from .client import Client as InnerClient, AuthenticatedClient +from .types import UNSET, File, Response, Unset + + +from .models import ( + {% for all in alls | sort %} + {{ all }}, + {% endfor %} +) + +from .api import ( + {% for tag, collection in endpoint_collections_by_tag.items() %} + {{ tag | snakecase }}, + {% endfor %} +) + +{% from "endpoint_macros.py.jinja" import arguments, client, kwargs %} + + +{% for tag, collection in endpoint_collections_by_tag.items() %} + +class {{ tag | pascalcase }}Api: + + def __init__(self, client: InnerClient): + self._client = client + + {% for endpoint in collection.endpoints %} + {% set return_string = endpoint.response_type() %} + {% set parsed_responses = (endpoint.responses | length > 0) and return_string != "None" %} + + {% if parsed_responses %} + async def {{ endpoint.name | snakecase }}( + self, + {{ arguments(endpoint, False) | indent(8) }} + ) -> Optional[{{ return_string | indent(4) }}]: + {% if endpoint.requires_security %} + client = cast(AuthenticatedClient, self._client) + {% else %} + client = self._client + {% endif %} + return await {{ tag }}.{{ endpoint.name | snakecase }}.asyncio( + {{ kwargs(endpoint) | indent(12) }} + ) + {% endif %} + + async def {{ endpoint.name | snakecase }}_detailed( + self, + {{ arguments(endpoint, False) | indent(8) }} + ) -> Response[{{ return_string | indent(4) }}]: + {% if endpoint.requires_security %} + client = cast(AuthenticatedClient, self._client) + {% else %} + client = self._client + {% endif %} + return await {{ tag }}.{{ endpoint.name | snakecase }}.asyncio_detailed( + {{ kwargs(endpoint) | indent(12) }} + ) + + {% endfor %} + + +class Sync{{ tag | pascalcase }}Api: + + def __init__(self, client: InnerClient): + self._client = client + + {% for endpoint in collection.endpoints %} + {% set return_string = endpoint.response_type() %} + {% set parsed_responses = (endpoint.responses | length > 0) and return_string != "None" %} + + {% if parsed_responses %} + def {{ endpoint.name | snakecase }}( + self, + {{ arguments(endpoint, False) | indent(8) }} + ) -> Optional[{{ return_string | indent(4) }}]: + {% if endpoint.requires_security %} + client = cast(AuthenticatedClient, self._client) + {% else %} + client = self._client + {% endif %} + return {{ tag }}.{{ endpoint.name | snakecase }}.sync( + {{ kwargs(endpoint) | indent(12) }} + ) + {% endif %} + + def {{ endpoint.name | snakecase }}_detailed( + self, + {{ arguments(endpoint, False) | indent(8) }} + ) -> Response[{{ return_string | indent(4) }}]: + {% if endpoint.requires_security %} + client = cast(AuthenticatedClient, self._client) + {% else %} + client = self._client + {% endif %} + return {{ tag }}.{{ endpoint.name | snakecase }}.sync_detailed( + {{ kwargs(endpoint) | indent(12) }} + ) + + {% endfor %} + +{% endfor %} + +{% for prefix in '', 'Sync' %} + +class {{ prefix }}{{ openapi.title | pascalcase }}Client: + def __init__( + self, + base_url: Optional[str] = None, + timeout: float = 5.0, + token: Optional[str] = None, + header_prefix: str = "Bearer", + header_name: str = "Authorization", + verify_ssl: bool = True, + ): + if token is None: + self.connection = InnerClient( + base_url=base_url, + timeout=timeout, + verify_ssl=verify_ssl, + raise_on_unexpected_status=True, + follow_redirects=True, + ) + else: + self.connection = AuthenticatedClient( + base_url=base_url, + timeout=timeout, + token=token, + prefix=header_prefix, + auth_header_name=header_name, + verify_ssl=verify_ssl, + raise_on_unexpected_status=True, + follow_redirects=True, + ) + {% for tag, collection in endpoint_collections_by_tag.items() %} + self.{{ tag | snakecase }} = {{ prefix }}{{ tag | pascalcase }}Api(self.connection) + {% endfor %} + +{% endfor %}