Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: object oriented sync and async clients #67

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion integration-tests/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@

@pytest.fixture(scope="session")
def client() -> Client:
return Client("http://localhost:3000")
return Client(base_url="http://localhost:3000")
10 changes: 10 additions & 0 deletions openapi_python_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,16 @@ def _build_models(self) -> None:
models_init_template = self.env.get_template("models_init.py.jinja")
models_init.write_text(models_init_template.render(imports=imports, alls=alls), encoding=self.file_encoding)

# Generate wrapper
wrapper = self.package_dir / "wrapper.py"
wrapper_template = self.env.get_template("wrapper.py.jinja")
wrapper.write_text(
wrapper_template.render(
imports=imports,
alls=alls,
)
)

# pylint: disable=too-many-locals
def _build_api(self) -> None:
# Generate Client
Expand Down
14 changes: 12 additions & 2 deletions openapi_python_client/templates/client.py.jinja
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import ssl
from typing import Dict, Union
import os
from typing import Dict, Union, Optional

import attr

@attr.s(auto_attribs=True)
Expand All @@ -19,14 +21,22 @@ class Client:
follow_redirects: Whether or not to follow redirects. Default value is False.
"""

base_url: str
base_url: Optional[str] = attr.ib(None, kw_only=True)
cookies: Dict[str, str] = attr.ib(factory=dict, kw_only=True)
headers: Dict[str, str] = attr.ib(factory=dict, kw_only=True)
timeout: float = attr.ib(5.0, kw_only=True)
verify_ssl: Union[str, bool, ssl.SSLContext] = attr.ib(True, kw_only=True)
raise_on_unexpected_status: bool = attr.ib(False, kw_only=True)
follow_redirects: bool = attr.ib(False, kw_only=True)

def __attrs_post_init__(self) -> None:
env_base_url = os.environ.get('{{ openapi.title | snakecase | upper }}_BASE_URL')
self.base_url = self.base_url or env_base_url
if self.base_url is None:
raise ValueError(f'"base_url" has to be set either from the '
f'environment variable "{env_base_url}", or '
f'passed with the "base_url" argument')

def get_headers(self) -> Dict[str, str]:
""" Get headers to be used in all endpoints """
return {**self.headers}
Expand Down
1 change: 1 addition & 0 deletions openapi_python_client/templates/endpoint_init.py.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import {% for e in endpoint_collection.endpoints %} {{e.name | snakecase }}, {% endfor %}
6 changes: 5 additions & 1 deletion openapi_python_client/templates/endpoint_macros.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -78,18 +78,22 @@ params = {k: v for k, v in params.items() if v is not UNSET and v is not None}
{% endmacro %}

{# The all the kwargs passed into an endpoint (and variants thereof)) #}
{% macro arguments(endpoint) %}
{% macro arguments(endpoint, client=True) %}
{# path parameters #}
{% for parameter in endpoint.path_parameters.values() %}
{{ parameter.to_string() }},
{% endfor %}
{% if endpoint.form_body_reference or endpoint.multipart_body_reference or endpoint.query_parameters or endpoint.json_body or endpoint.header_parameters or endpoint.cookie_parameters %}
*,
{% endif %}
{# Proper client based on whether or not the endpoint requires authentication #}
{% if client %}
{% if endpoint.requires_security %}
client: AuthenticatedClient,
{% else %}
client: Client,
{% endif %}
{% endif %}
{# Form data if any #}
{% if endpoint.form_body %}
form_data: {{ endpoint.form_body.get_type_string() }},
Expand Down
6 changes: 3 additions & 3 deletions openapi_python_client/templates/package_init.py.jinja
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
{% from "helpers.jinja" import safe_docstring %}

{{ safe_docstring(package_description) }}
from .client import AuthenticatedClient, Client
from .wrapper import Sync{{ openapi.title | pascalcase }}Client, {{ openapi.title | pascalcase }}Client

__all__ = (
"AuthenticatedClient",
"Client",
"Sync{{ openapi.title | pascalcase }}Client",
"{{ openapi.title | pascalcase }}Client"
)
2 changes: 1 addition & 1 deletion openapi_python_client/templates/pyproject.toml.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ include = ["CHANGELOG.md", "{{ package_name }}/py.typed"]

[tool.poetry.dependencies]
python = "^3.8"
httpx = ">=0.15.4,<0.25.0"
httpx = ">=0.15.4"
attrs = ">=21.3.0"
python-dateutil = "^2.8.0"

Expand Down
2 changes: 1 addition & 1 deletion openapi_python_client/templates/setup.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ setup(
long_description_content_type="text/markdown",
packages=find_packages(),
python_requires=">=3.8, <4",
install_requires=["httpx >= 0.15.0, < 0.25.0", "attrs >= 21.3.0", "python-dateutil >= 2.8.0, < 3"],
install_requires=["httpx >= 0.15.0", "attrs >= 21.3.0", "python-dateutil >= 2.8.0, < 3"],
package_data={"{{ package_name }}": ["py.typed"]},
)
141 changes: 141 additions & 0 deletions openapi_python_client/templates/wrapper.py.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import datetime
from dateutil.parser import isoparse
from typing import Any, Dict, Optional, Union, cast, List
from .client import Client as InnerClient, AuthenticatedClient
from .types import UNSET, File, Response, Unset


from .models import (
{% for all in alls | sort %}
{{ all }},
{% endfor %}
)

from .api import (
{% for tag, collection in endpoint_collections_by_tag.items() %}
{{ tag | snakecase }},
{% endfor %}
)

{% from "endpoint_macros.py.jinja" import arguments, client, kwargs %}


{% for tag, collection in endpoint_collections_by_tag.items() %}

class {{ tag | pascalcase }}Api:

def __init__(self, client: InnerClient):
self._client = client

{% for endpoint in collection.endpoints %}
{% set return_string = endpoint.response_type() %}
{% set parsed_responses = (endpoint.responses | length > 0) and return_string != "None" %}

{% if parsed_responses %}
async def {{ endpoint.name | snakecase }}(
self,
{{ arguments(endpoint, False) | indent(8) }}
) -> Optional[{{ return_string | indent(4) }}]:
{% if endpoint.requires_security %}
client = cast(AuthenticatedClient, self._client)
{% else %}
client = self._client
{% endif %}
return await {{ tag }}.{{ endpoint.name | snakecase }}.asyncio(
{{ kwargs(endpoint) | indent(12) }}
)
{% endif %}

async def {{ endpoint.name | snakecase }}_detailed(
self,
{{ arguments(endpoint, False) | indent(8) }}
) -> Response[{{ return_string | indent(4) }}]:
{% if endpoint.requires_security %}
client = cast(AuthenticatedClient, self._client)
{% else %}
client = self._client
{% endif %}
return await {{ tag }}.{{ endpoint.name | snakecase }}.asyncio_detailed(
{{ kwargs(endpoint) | indent(12) }}
)

{% endfor %}


class Sync{{ tag | pascalcase }}Api:

def __init__(self, client: InnerClient):
self._client = client

{% for endpoint in collection.endpoints %}
{% set return_string = endpoint.response_type() %}
{% set parsed_responses = (endpoint.responses | length > 0) and return_string != "None" %}

{% if parsed_responses %}
def {{ endpoint.name | snakecase }}(
self,
{{ arguments(endpoint, False) | indent(8) }}
) -> Optional[{{ return_string | indent(4) }}]:
{% if endpoint.requires_security %}
client = cast(AuthenticatedClient, self._client)
{% else %}
client = self._client
{% endif %}
return {{ tag }}.{{ endpoint.name | snakecase }}.sync(
{{ kwargs(endpoint) | indent(12) }}
)
{% endif %}

def {{ endpoint.name | snakecase }}_detailed(
self,
{{ arguments(endpoint, False) | indent(8) }}
) -> Response[{{ return_string | indent(4) }}]:
{% if endpoint.requires_security %}
client = cast(AuthenticatedClient, self._client)
{% else %}
client = self._client
{% endif %}
return {{ tag }}.{{ endpoint.name | snakecase }}.sync_detailed(
{{ kwargs(endpoint) | indent(12) }}
)

{% endfor %}

{% endfor %}

{% for prefix in '', 'Sync' %}

class {{ prefix }}{{ openapi.title | pascalcase }}Client:
def __init__(
self,
base_url: Optional[str] = None,
timeout: float = 5.0,
token: Optional[str] = None,
header_prefix: str = "Bearer",
header_name: str = "Authorization",
verify_ssl: bool = True,
):
if token is None:
self.connection = InnerClient(
base_url=base_url,
timeout=timeout,
verify_ssl=verify_ssl,
raise_on_unexpected_status=True,
follow_redirects=True,
)
else:
self.connection = AuthenticatedClient(
base_url=base_url,
timeout=timeout,
token=token,
prefix=header_prefix,
auth_header_name=header_name,
verify_ssl=verify_ssl,
raise_on_unexpected_status=True,
follow_redirects=True,
)
{% for tag, collection in endpoint_collections_by_tag.items() %}
self.{{ tag | snakecase }} = {{ prefix }}{{ tag | pascalcase }}Api(self.connection)
{% endfor %}

{% endfor %}