Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding generics #343

Closed
wants to merge 3 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion arango/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__all__ = ["ApiGroup"]

from typing import Callable, Optional, TypeVar
from typing import Any, Callable, Optional, TypeVar

from arango.connection import Connection
from arango.executor import ApiExecutor
13 changes: 8 additions & 5 deletions arango/aql.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
__all__ = ["AQL", "AQLQueryCache"]

from numbers import Number
from typing import MutableMapping, Optional, Sequence, Union
from typing import Generic, MutableMapping, Optional, Sequence, TypeVar, Union

from arango.api import ApiGroup
from arango.connection import Connection
@@ -145,14 +145,17 @@ def response_handler(resp: Response) -> bool:
return self._execute(request, response_handler)


class AQL(ApiGroup):
T = TypeVar("T")


class AQL(ApiGroup, Generic[T]):
"""AQL (ArangoDB Query Language) API wrapper.

:param connection: HTTP connection.
:param executor: API executor.
"""

def __init__(self, connection: Connection, executor: ApiExecutor) -> None:
def __init__(self, connection: Connection[T], executor: ApiExecutor) -> None:
super().__init__(connection, executor)

def __repr__(self) -> str:
@@ -173,7 +176,7 @@ def explain(
all_plans: bool = False,
max_plans: Optional[int] = None,
opt_rules: Optional[Sequence[str]] = None,
bind_vars: Optional[MutableMapping[str, DataTypes]] = None,
bind_vars: Optional[MutableMapping[str, Union[DataTypes, T]]] = None,
) -> Result[Union[Json, Jsons]]:
"""Inspect the query and return its metadata without executing it.

@@ -257,7 +260,7 @@ def execute(
count: bool = False,
batch_size: Optional[int] = None,
ttl: Optional[Number] = None,
bind_vars: Optional[MutableMapping[str, DataTypes]] = None,
bind_vars: Optional[MutableMapping[str, Union[DataTypes, T]]] = None,
full_count: Optional[bool] = None,
max_plans: Optional[int] = None,
optimizer_rules: Optional[Sequence[str]] = None,
42 changes: 9 additions & 33 deletions arango/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
__all__ = ["ArangoClient"]

from json import dumps, loads
from typing import Any, Callable, Optional, Sequence, Union
from typing import Any, Generic, Optional, Sequence, TypeVar, Union

import importlib_metadata

@@ -27,33 +26,13 @@
RoundRobinHostResolver,
SingleHostResolver,
)
from arango.serializer import Deserializer, JsonDeserializer, JsonSerializer, Serializer
from arango.typings import DataTypes


def default_serializer(x: Any) -> str:
"""
Default JSON serializer

:param x: A JSON data type object to serialize
:type x: Any
:return: The object serialized as a JSON string
:rtype: str
"""
return dumps(x, separators=(",", ":"))


def default_deserializer(x: str) -> Any:
"""
Default JSON de-serializer

:param x: A JSON string to deserialize
:type x: str
:return: The de-serialized JSON object
:rtype: Any
"""
return loads(x)
T = TypeVar("T")


class ArangoClient:
class ArangoClient(Generic[T]):
"""ArangoDB client.

:param hosts: Host URL or list of URLs (coordinators in a cluster).
@@ -104,8 +83,8 @@ def __init__(
host_resolver: Union[str, HostResolver] = "fallback",
resolver_max_tries: Optional[int] = None,
http_client: Optional[HTTPClient] = None,
serializer: Callable[..., str] = default_serializer,
deserializer: Callable[[str], Any] = default_deserializer,
serializer: Serializer[T] = JsonSerializer(),
deserializer: Deserializer[DataTypes] = JsonDeserializer(),
verify_override: Union[bool, str, None] = None,
request_timeout: Union[int, float, None] = DEFAULT_REQUEST_TIMEOUT,
request_compression: Optional[RequestCompression] = None,
@@ -199,8 +178,7 @@ def db(
auth_method: str = "basic",
user_token: Optional[str] = None,
superuser_token: Optional[str] = None,
verify_certificate: bool = True,
) -> StandardDatabase:
) -> StandardDatabase[T]:
"""Connect to an ArangoDB database and return the database API wrapper.

:param name: Database name.
@@ -228,14 +206,12 @@ def db(
are ignored. This token is not refreshed automatically. Token
expiry will not be checked.
:type superuser_token: str
:param verify_certificate: Verify TLS certificates.
:type verify_certificate: bool
:return: Standard database API wrapper.
:rtype: arango.database.StandardDatabase
:raise arango.exceptions.ServerConnectionError: If **verify** was set
to True and the connection fails.
"""
connection: Connection
connection: Connection[T]

if superuser_token is not None:
connection = JwtSuperuserConnection(
64 changes: 33 additions & 31 deletions arango/collection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
__all__ = ["StandardCollection", "VertexCollection", "EdgeCollection"]

from numbers import Number
from typing import List, Optional, Sequence, Tuple, Union
from typing import Any, Generic, List, Optional, Sequence, Tuple, TypeVar, Union

from arango.api import ApiGroup
from arango.connection import Connection
@@ -56,8 +56,10 @@
is_none_or_str,
)

T = TypeVar("T", bound=dict[Any, Any])

class Collection(ApiGroup):

class Collection(ApiGroup, Generic[T]):
"""Base class for collection API wrappers.
:param connection: HTTP connection.
@@ -149,7 +151,7 @@ def _prep_from_body(self, document: Json, check_rev: bool) -> Tuple[str, Headers
return doc_id, {"If-Match": document["_rev"]}

def _prep_from_doc(
self, document: Union[str, Json], rev: Optional[str], check_rev: bool
self, document: Any, rev: Optional[str], check_rev: bool
) -> Tuple[str, Union[str, Json], Json]:
"""Prepare document ID, body and request headers.
@@ -199,7 +201,7 @@ def _ensure_key_in_body(self, body: Json) -> Json:
return body
raise DocumentParseError('field "_key" or "_id" required')

def _ensure_key_from_id(self, body: Json) -> Json:
def _ensure_key_from_id(self, body: T) -> T:
"""Return the body with "_key" field if it has "_id" field.
:param body: Document body.
@@ -600,7 +602,7 @@ def response_handler(resp: Response) -> int:

def has(
self,
document: Union[str, Json],
document: Union[str, Json, T],
rev: Optional[str] = None,
check_rev: bool = True,
allow_dirty_read: bool = False,
@@ -1762,7 +1764,7 @@ def response_handler(resp: Response) -> bool:

def insert_many(
self,
documents: Sequence[Json],
documents: Sequence[T],
return_new: bool = False,
sync: Optional[bool] = None,
silent: bool = False,
@@ -1890,7 +1892,7 @@ def response_handler(

def update_many(
self,
documents: Sequence[Json],
documents: Sequence[T],
check_rev: bool = True,
merge: bool = True,
keep_none: bool = True,
@@ -2023,7 +2025,7 @@ def response_handler(
def update_match(
self,
filters: Json,
body: Json,
body: Union[Json, T],
limit: Optional[int] = None,
keep_none: bool = True,
sync: Optional[bool] = None,
@@ -2103,7 +2105,7 @@ def response_handler(resp: Response) -> int:

def replace_many(
self,
documents: Sequence[Json],
documents: Sequence[T],
check_rev: bool = True,
return_new: bool = False,
return_old: bool = False,
@@ -2217,7 +2219,7 @@ def response_handler(
def replace_match(
self,
filters: Json,
body: Json,
body: T,
limit: Optional[int] = None,
sync: Optional[bool] = None,
allow_dirty_read: bool = False,
@@ -2283,7 +2285,7 @@ def response_handler(resp: Response) -> int:

def delete_many(
self,
documents: Sequence[Json],
documents: Sequence[Union[T, Json]],
return_old: bool = False,
check_rev: bool = True,
sync: Optional[bool] = None,
@@ -2448,7 +2450,7 @@ def response_handler(resp: Response) -> int:

def import_bulk(
self,
documents: Sequence[Json],
documents: Sequence[T],
halt_on_error: bool = True,
details: bool = True,
from_prefix: Optional[str] = None,
@@ -2585,7 +2587,7 @@ def __getitem__(self, key: Union[str, Json]) -> Result[Optional[Json]]:

def get(
self,
document: Union[str, Json],
document: Union[str, Json, T],
rev: Optional[str] = None,
check_rev: bool = True,
allow_dirty_read: bool = False,
@@ -2635,7 +2637,7 @@ def response_handler(resp: Response) -> Optional[Json]:

def insert(
self,
document: Json,
document: T,
return_new: bool = False,
sync: Optional[bool] = None,
silent: bool = False,
@@ -2739,7 +2741,7 @@ def response_handler(resp: Response) -> Union[bool, Json]:

def update(
self,
document: Json,
document: Union[Json, T],
check_rev: bool = True,
merge: bool = True,
keep_none: bool = True,
@@ -2831,7 +2833,7 @@ def response_handler(resp: Response) -> Union[bool, Json]:

def replace(
self,
document: Json,
document: T,
check_rev: bool = True,
return_new: bool = False,
return_old: bool = False,
@@ -2916,7 +2918,7 @@ def response_handler(resp: Response) -> Union[bool, Json]:

def delete(
self,
document: Union[str, Json],
document: Union[str, Json, T],
rev: Optional[str] = None,
check_rev: bool = True,
ignore_missing: bool = False,
@@ -2995,7 +2997,7 @@ def response_handler(resp: Response) -> Union[bool, Json]:
return self._execute(request, response_handler)


class VertexCollection(Collection):
class VertexCollection(Collection[T]):
"""Vertex collection API wrapper.
:param connection: HTTP connection.
@@ -3005,7 +3007,7 @@ class VertexCollection(Collection):
"""

def __init__(
self, connection: Connection, executor: ApiExecutor, graph: str, name: str
self, connection: Connection[T], executor: ApiExecutor, graph: str, name: str
) -> None:
super().__init__(connection, executor, name)
self._graph = graph
@@ -3070,7 +3072,7 @@ def response_handler(resp: Response) -> Optional[Json]:

def insert(
self,
vertex: Json,
vertex: T,
sync: Optional[bool] = None,
silent: bool = False,
return_new: bool = False,
@@ -3119,7 +3121,7 @@ def response_handler(resp: Response) -> Union[bool, Json]:

def update(
self,
vertex: Json,
vertex: Union[Json, T],
check_rev: bool = True,
keep_none: bool = True,
sync: Optional[bool] = None,
@@ -3189,7 +3191,7 @@ def response_handler(resp: Response) -> Union[bool, Json]:

def replace(
self,
vertex: Json,
vertex: T,
check_rev: bool = True,
sync: Optional[bool] = None,
silent: bool = False,
@@ -3253,7 +3255,7 @@ def response_handler(resp: Response) -> Union[bool, Json]:

def delete(
self,
vertex: Union[str, Json],
vertex: Union[str, Json, T],
rev: Optional[str] = None,
check_rev: bool = True,
ignore_missing: bool = False,
@@ -3315,7 +3317,7 @@ def response_handler(resp: Response) -> Union[bool, Json]:
return self._execute(request, response_handler)


class EdgeCollection(Collection):
class EdgeCollection(Collection[T]):
"""ArangoDB edge collection API wrapper.
:param connection: HTTP connection.
@@ -3388,7 +3390,7 @@ def response_handler(resp: Response) -> Optional[Json]:

def insert(
self,
edge: Json,
edge: T,
sync: Optional[bool] = None,
silent: bool = False,
return_new: bool = False,
@@ -3438,7 +3440,7 @@ def response_handler(resp: Response) -> Union[bool, Json]:

def update(
self,
edge: Json,
edge: Union[Json, T],
check_rev: bool = True,
keep_none: bool = True,
sync: Optional[bool] = None,
@@ -3508,7 +3510,7 @@ def response_handler(resp: Response) -> Union[bool, Json]:

def replace(
self,
edge: Json,
edge: T,
check_rev: bool = True,
sync: Optional[bool] = None,
silent: bool = False,
@@ -3573,7 +3575,7 @@ def response_handler(resp: Response) -> Union[bool, Json]:

def delete(
self,
edge: Union[str, Json],
edge: Union[str, Json, T],
rev: Optional[str] = None,
check_rev: bool = True,
ignore_missing: bool = False,
@@ -3635,8 +3637,8 @@ def response_handler(resp: Response) -> Union[bool, Json]:

def link(
self,
from_vertex: Union[str, Json],
to_vertex: Union[str, Json],
from_vertex: Union[str, Json, T],
to_vertex: Union[str, Json, T],
data: Optional[Json] = None,
sync: Optional[bool] = None,
silent: bool = False,
@@ -3672,7 +3674,7 @@ def link(

def edges(
self,
vertex: Union[str, Json],
vertex: Union[str, Json, T],
direction: Optional[str] = None,
allow_dirty_read: bool = False,
) -> Result[Json]:
47 changes: 29 additions & 18 deletions arango/connection.py
Original file line number Diff line number Diff line change
@@ -10,7 +10,8 @@
import sys
import time
from abc import abstractmethod
from typing import Any, Callable, Optional, Sequence, Set, Tuple, Union
from json import dumps, loads
from typing import Generic, List, Optional, Sequence, Set, Tuple, TypeVar, Union

import jwt
from jwt.exceptions import ExpiredSignatureError
@@ -27,12 +28,15 @@
from arango.request import Request
from arango.resolver import HostResolver
from arango.response import Response
from arango.typings import Fields, Json
from arango.serializer import Deserializer, Serializer
from arango.typings import DataTypes, Fields, Json

Connection = Union["BasicConnection", "JwtConnection", "JwtSuperuserConnection"]

T = TypeVar("T")

class BaseConnection:

class BaseConnection(Generic[T]):
"""Base connection to a specific ArangoDB database."""

def __init__(
@@ -42,8 +46,8 @@ def __init__(
sessions: Sequence[Session],
db_name: str,
http_client: HTTPClient,
serializer: Callable[..., str],
deserializer: Callable[[str], Any],
serializer: Serializer[T],
deserializer: Deserializer[DataTypes],
request_compression: Optional[RequestCompression] = None,
response_compression: Optional[str] = None,
) -> None:
@@ -76,7 +80,7 @@ def username(self) -> Optional[str]:
"""
return self._username

def serialize(self, obj: Any) -> str:
def serialize(self, obj: Union[T, Sequence[T]]) -> str:
"""Serialize the given object.
:param obj: JSON object to serialize.
@@ -86,7 +90,7 @@ def serialize(self, obj: Any) -> str:
"""
return self._serializer(obj)

def deserialize(self, string: str) -> Any:
def deserialize(self, string: str) -> DataTypes:
"""De-serialize the string and return the object.
:param string: String to de-serialize.
@@ -194,15 +198,17 @@ def prep_bulk_err_response(self, parent_response: Response, body: Json) -> Respo
headers=parent_response.headers,
status_code=parent_response.status_code,
status_text=parent_response.status_text,
raw_body=self.serialize(body),
raw_body=dumps(body, separators=(",", ":")),
)
resp.body = body
resp.error_code = body["errorNum"]
resp.error_message = body["errorMessage"]
resp.is_success = False
return resp

def normalize_data(self, data: Any) -> Union[str, MultipartEncoder, None]:
def normalize_data(
self, data: Union[str, MultipartEncoder, T, List[T], None]
) -> Union[str, MultipartEncoder, None]:
"""Normalize request data.
:param data: Request data.
@@ -247,7 +253,7 @@ def send_request(self, request: Request) -> Response: # pragma: no cover
raise NotImplementedError


class BasicConnection(BaseConnection):
class BasicConnection(BaseConnection[T]):
"""Connection to specific ArangoDB database using basic authentication.
:param hosts: Host URL or list of URLs (coordinators in a cluster).
@@ -279,8 +285,8 @@ def __init__(
username: str,
password: str,
http_client: HTTPClient,
serializer: Callable[..., str],
deserializer: Callable[[str], Any],
serializer: Serializer[T],
deserializer: Deserializer[DataTypes],
request_compression: Optional[RequestCompression] = None,
response_compression: Optional[str] = None,
) -> None:
@@ -310,7 +316,7 @@ def send_request(self, request: Request) -> Response:
return self.process_request(host_index, request, auth=self._auth)


class JwtConnection(BaseConnection):
class JwtConnection(BaseConnection[T]):
"""Connection to specific ArangoDB database using JWT authentication.
:param hosts: Host URL or list of URLs (coordinators in a cluster).
@@ -340,8 +346,8 @@ def __init__(
sessions: Sequence[Session],
db_name: str,
http_client: HTTPClient,
serializer: Callable[..., str],
deserializer: Callable[[str], Any],
serializer: Serializer[T],
deserializer: Deserializer[DataTypes],
username: Optional[str] = None,
password: Optional[str] = None,
user_token: Optional[str] = None,
@@ -420,6 +426,7 @@ def refresh_token(self) -> None:
method="post",
endpoint="/_open/auth",
data={"username": self._username, "password": self._password},
deserialize=False,
)

host_index = self._host_resolver.get_host_index()
@@ -429,6 +436,10 @@ def refresh_token(self) -> None:
if not resp.is_success:
raise JWTAuthError(resp, request)

resp.body = loads(resp.raw_body)
if "jwt" not in resp.body:
raise JWTRefreshError(f"JWT token not found in response body: {resp.body}")

self.set_token(resp.body["jwt"])

def set_token(self, token: str) -> None:
@@ -461,7 +472,7 @@ def set_token(self, token: str) -> None:
self._auth_header = f"bearer {self._token}"


class JwtSuperuserConnection(BaseConnection):
class JwtSuperuserConnection(BaseConnection[T]):
"""Connection to specific ArangoDB database using superuser JWT.
:param hosts: Host URL or list of URLs (coordinators in a cluster).
@@ -489,9 +500,9 @@ def __init__(
sessions: Sequence[Session],
db_name: str,
http_client: HTTPClient,
serializer: Callable[..., str],
deserializer: Callable[[str], Any],
superuser_token: str,
serializer: Serializer[T],
deserializer: Deserializer[DataTypes],
request_compression: Optional[RequestCompression] = None,
response_compression: Optional[str] = None,
) -> None:
2 changes: 1 addition & 1 deletion arango/cursor.py
Original file line number Diff line number Diff line change
@@ -50,7 +50,7 @@ class Cursor:

def __init__(
self,
connection: BaseConnection,
connection: BaseConnection[Any],
init_data: Json,
cursor_type: str = "cursor",
allow_retry: bool = False,
42 changes: 25 additions & 17 deletions arango/database.py
Original file line number Diff line number Diff line change
@@ -8,7 +8,7 @@

from datetime import datetime
from numbers import Number
from typing import Any, Dict, List, Optional, Sequence, Union
from typing import Any, Dict, Generic, List, Optional, Sequence, TypeVar, Union
from warnings import warn

from arango.api import ApiGroup
@@ -117,8 +117,10 @@
from arango.utils import get_col_name
from arango.wal import WAL

T = TypeVar("T")

class Database(ApiGroup):

class Database(ApiGroup, Generic[T]):
"""Base class for Database API wrappers."""

def __getitem__(self, name: str) -> StandardCollection:
@@ -152,7 +154,7 @@ def name(self) -> str:
return self.db_name

@property
def aql(self) -> AQL:
def aql(self) -> AQL[T]:
"""Return AQL (ArangoDB Query Language) API wrapper.
:return: AQL API wrapper.
@@ -1335,7 +1337,7 @@ def response_handler(resp: Response) -> bool:
# Collection Management #
#########################

def collection(self, name: str) -> StandardCollection:
def collection(self, name: str) -> StandardCollection[T]:
"""Return the standard collection API wrapper.
:param name: Collection name.
@@ -1408,7 +1410,7 @@ def create_collection(
write_concern: Optional[int] = None,
schema: Optional[Json] = None,
computedValues: Optional[Jsons] = None,
) -> Result[StandardCollection]:
) -> Result[StandardCollection[T]]:
"""Create a new collection.
:param name: Collection name.
@@ -1793,7 +1795,10 @@ def response_handler(resp: Response) -> bool:
#######################

def has_document(
self, document: Json, rev: Optional[str] = None, check_rev: bool = True
self,
document: Union[str, Json, T],
rev: Optional[str] = None,
check_rev: bool = True,
) -> Result[bool]:
"""Check if a document exists.
@@ -1815,7 +1820,10 @@ def has_document(
)

def document(
self, document: Json, rev: Optional[str] = None, check_rev: bool = True
self,
document: Union[str, Json, T],
rev: Optional[str] = None,
check_rev: bool = True,
) -> Result[Optional[Json]]:
"""Return a document.
@@ -1839,7 +1847,7 @@ def document(
def insert_document(
self,
collection: str,
document: Json,
document: T,
return_new: bool = False,
sync: Optional[bool] = None,
silent: bool = False,
@@ -1903,7 +1911,7 @@ def insert_document(

def update_document(
self,
document: Json,
document: Union[Json, T],
check_rev: bool = True,
merge: bool = True,
keep_none: bool = True,
@@ -1954,7 +1962,7 @@ def update_document(

def replace_document(
self,
document: Json,
document: T,
check_rev: bool = True,
return_new: bool = False,
return_old: bool = False,
@@ -1996,7 +2004,7 @@ def replace_document(

def delete_document(
self,
document: Union[str, Json],
document: Union[str, Json, T],
rev: Optional[str] = None,
check_rev: bool = True,
ignore_missing: bool = False,
@@ -2948,10 +2956,10 @@ def response_handler(resp: Response) -> Json:
return self._execute(request, response_handler)


class StandardDatabase(Database):
class StandardDatabase(Database[T]):
"""Standard database API wrapper."""

def __init__(self, connection: Connection) -> None:
def __init__(self, connection: Connection[T]) -> None:
super().__init__(connection=connection, executor=DefaultApiExecutor(connection))

def __repr__(self) -> str:
@@ -3071,7 +3079,7 @@ def begin_controlled_execution(
return OverloadControlDatabase(self._conn, max_queue_time_seconds)


class AsyncDatabase(Database):
class AsyncDatabase(Database[T]):
"""Database API wrapper tailored specifically for async execution.
See :func:`arango.database.StandardDatabase.begin_async_execution`.
@@ -3094,7 +3102,7 @@ def __repr__(self) -> str:
return f"<AsyncDatabase {self.name}>"


class BatchDatabase(Database):
class BatchDatabase(Database[T]):
"""Database API wrapper tailored specifically for batch execution.
.. note::
@@ -3163,7 +3171,7 @@ def commit(self) -> Optional[Sequence[BatchJob[Any]]]:
return self._executor.commit()


class TransactionDatabase(Database):
class TransactionDatabase(Database[T]):
"""Database API wrapper tailored specifically for transactions.
See :func:`arango.database.StandardDatabase.begin_transaction`.
@@ -3261,7 +3269,7 @@ def abort_transaction(self) -> bool:
return self._executor.abort()


class OverloadControlDatabase(Database):
class OverloadControlDatabase(Database[T]):
"""Database API wrapper tailored to gracefully handle server overload scenarios.
See :func:`arango.database.StandardDatabase.begin_controlled_execution`.
5 changes: 3 additions & 2 deletions arango/foxx.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
__all__ = ["Foxx"]

import os
from json import dumps
from typing import Any, BinaryIO, Dict, Optional, Tuple, Union

from requests_toolbelt import MultipartEncoder
@@ -72,10 +73,10 @@ def _encode(
}

if config is not None:
fields["configuration"] = self._conn.serialize(config).encode("utf-8")
fields["configuration"] = dumps(config).encode("utf-8")

if dependencies is not None:
fields["dependencies"] = self._conn.serialize(dependencies).encode("utf-8")
fields["dependencies"] = dumps(dependencies).encode("utf-8")

return MultipartEncoder(fields=fields)

50 changes: 50 additions & 0 deletions arango/serializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
__all__ = [
"Serializer",
"Deserializer",
"JsonSerializer",
"JsonDeserializer",
]

from json import dumps, loads
from typing import Generic, Sequence, TypeVar, Union

from arango.typings import DataTypes, Json

T = TypeVar("T")


class Serializer(Generic[T]):
"""
Serializer interface.
For the use of bulk operations, it must also support List[T].
"""

def __call__(self, obj: Union[T, Sequence[T]]) -> str:
raise NotImplementedError


class Deserializer(Generic[T]):
"""
De-serializer interface
"""

def __call__(self, s: str) -> T:
raise NotImplementedError


class JsonSerializer(Serializer[Json]):
"""
Default JSON serializer
"""

def __call__(self, obj: Union[Json, Sequence[Json]]) -> str:
return dumps(obj, separators=(",", ":"))


class JsonDeserializer(Deserializer[DataTypes]):
"""
Default JSON de-serializer
"""

def __call__(self, s: str) -> DataTypes:
return loads(s)
4 changes: 2 additions & 2 deletions arango/utils.py
Original file line number Diff line number Diff line change
@@ -96,12 +96,12 @@ def is_none_or_bool(obj: Any) -> bool:
return obj is None or isinstance(obj, bool)


def get_batches(elements: Sequence[Json], batch_size: int) -> Iterator[Sequence[Json]]:
def get_batches(elements: Sequence[Any], batch_size: int) -> Iterator[Sequence[Any]]:
"""Generator to split a list in batches
of (maximum) **batch_size** elements each.
:param elements: The list of elements.
:type elements: Sequence[Json]
:type elements: Sequence[Any]
:param batch_size: Max number of elements per batch.
:type batch_size: int
"""
7 changes: 2 additions & 5 deletions arango/wal.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
__all__ = ["WAL"]

from json import loads
from typing import Optional

from arango.api import ApiGroup
@@ -269,11 +270,7 @@ def response_handler(resp: Response) -> Json:
if resp.is_success:
result = format_replication_header(resp.headers)
result["content"] = (
[
self._conn.deserialize(line)
for line in resp.body.split("\n")
if line
]
[loads(line) for line in resp.body.split("\n") if line]
if deserialize
else resp.body
)