From c694ffe82c74faf3ddaa8cdd483d7e5241031cd4 Mon Sep 17 00:00:00 2001 From: Alex Petenchea Date: Sun, 19 May 2024 20:36:57 +0300 Subject: [PATCH 1/3] Adding serializers file --- arango/serializer.py | 47 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 arango/serializer.py diff --git a/arango/serializer.py b/arango/serializer.py new file mode 100644 index 00000000..ea72ccc7 --- /dev/null +++ b/arango/serializer.py @@ -0,0 +1,47 @@ +__all__ = [ + "Serializer", + "Deserializer", + "JsonSerializer", + "JsonDeserializer", +] + +from json import dumps, loads +from typing import Any, Generic, TypeVar + +T = TypeVar("T") + + +class Serializer(Generic[T]): + """ + Serializer interface + """ + + def __call__(self, data: T) -> str: + raise NotImplementedError + + +class Deserializer: + """ + De-serializer interface + """ + + def __call__(self, data: str) -> Any: + raise NotImplementedError + + +class JsonSerializer(Serializer[Any]): + """ + Default JSON serializer + """ + + def __call__(self, data: Any) -> str: + return dumps(data, separators=(",", ":")) + + +class JsonDeserializer(Deserializer): + """ + Default JSON de-serializer + """ + + def __call__(self, data: str) -> Any: + return loads(data) From 6d6f17c0d6eac874b0a96da86246c0a598621632 Mon Sep 17 00:00:00 2001 From: Alex Petenchea Date: Mon, 20 May 2024 15:11:46 +0300 Subject: [PATCH 2/3] Decoupling config response types from default deserializer --- arango/foxx.py | 5 +++-- arango/wal.py | 7 ++----- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/arango/foxx.py b/arango/foxx.py index 41f61f52..7b02ae3f 100644 --- a/arango/foxx.py +++ b/arango/foxx.py @@ -1,6 +1,7 @@ __all__ = ["Foxx"] import os +from json import dumps from typing import Any, BinaryIO, Dict, Optional, Tuple, Union from requests_toolbelt import MultipartEncoder @@ -72,10 +73,10 @@ def _encode( } if config is not None: - fields["configuration"] = self._conn.serialize(config).encode("utf-8") + fields["configuration"] = dumps(config).encode("utf-8") if dependencies is not None: - fields["dependencies"] = self._conn.serialize(dependencies).encode("utf-8") + fields["dependencies"] = dumps(dependencies).encode("utf-8") return MultipartEncoder(fields=fields) diff --git a/arango/wal.py b/arango/wal.py index 9ad089f5..0bac9cd0 100644 --- a/arango/wal.py +++ b/arango/wal.py @@ -1,5 +1,6 @@ __all__ = ["WAL"] +from json import loads from typing import Optional from arango.api import ApiGroup @@ -269,11 +270,7 @@ def response_handler(resp: Response) -> Json: if resp.is_success: result = format_replication_header(resp.headers) result["content"] = ( - [ - self._conn.deserialize(line) - for line in resp.body.split("\n") - if line - ] + [loads(line) for line in resp.body.split("\n") if line] if deserialize else resp.body ) From 65b6b9e39ca5046041e6e99851b931fde8ccab29 Mon Sep 17 00:00:00 2001 From: Alex Petenchea Date: Mon, 20 May 2024 17:47:42 +0300 Subject: [PATCH 3/3] Supporting typing when a custom serializer is being used --- arango/api.py | 2 +- arango/aql.py | 13 +++++---- arango/client.py | 42 +++++++---------------------- arango/collection.py | 64 +++++++++++++++++++++++--------------------- arango/connection.py | 47 +++++++++++++++++++------------- arango/cursor.py | 2 +- arango/database.py | 42 +++++++++++++++++------------ arango/serializer.py | 25 +++++++++-------- arango/utils.py | 4 +-- 9 files changed, 122 insertions(+), 119 deletions(-) diff --git a/arango/api.py b/arango/api.py index c57f4c1c..bf8e113b 100644 --- a/arango/api.py +++ b/arango/api.py @@ -1,6 +1,6 @@ __all__ = ["ApiGroup"] -from typing import Callable, Optional, TypeVar +from typing import Any, Callable, Optional, TypeVar from arango.connection import Connection from arango.executor import ApiExecutor diff --git a/arango/aql.py b/arango/aql.py index 941000e5..0ab4f6b2 100644 --- a/arango/aql.py +++ b/arango/aql.py @@ -1,7 +1,7 @@ __all__ = ["AQL", "AQLQueryCache"] from numbers import Number -from typing import MutableMapping, Optional, Sequence, Union +from typing import Generic, MutableMapping, Optional, Sequence, TypeVar, Union from arango.api import ApiGroup from arango.connection import Connection @@ -145,14 +145,17 @@ def response_handler(resp: Response) -> bool: return self._execute(request, response_handler) -class AQL(ApiGroup): +T = TypeVar("T") + + +class AQL(ApiGroup, Generic[T]): """AQL (ArangoDB Query Language) API wrapper. :param connection: HTTP connection. :param executor: API executor. """ - def __init__(self, connection: Connection, executor: ApiExecutor) -> None: + def __init__(self, connection: Connection[T], executor: ApiExecutor) -> None: super().__init__(connection, executor) def __repr__(self) -> str: @@ -173,7 +176,7 @@ def explain( all_plans: bool = False, max_plans: Optional[int] = None, opt_rules: Optional[Sequence[str]] = None, - bind_vars: Optional[MutableMapping[str, DataTypes]] = None, + bind_vars: Optional[MutableMapping[str, Union[DataTypes, T]]] = None, ) -> Result[Union[Json, Jsons]]: """Inspect the query and return its metadata without executing it. @@ -257,7 +260,7 @@ def execute( count: bool = False, batch_size: Optional[int] = None, ttl: Optional[Number] = None, - bind_vars: Optional[MutableMapping[str, DataTypes]] = None, + bind_vars: Optional[MutableMapping[str, Union[DataTypes, T]]] = None, full_count: Optional[bool] = None, max_plans: Optional[int] = None, optimizer_rules: Optional[Sequence[str]] = None, diff --git a/arango/client.py b/arango/client.py index 0ccaa19f..79ca1764 100644 --- a/arango/client.py +++ b/arango/client.py @@ -1,7 +1,6 @@ __all__ = ["ArangoClient"] -from json import dumps, loads -from typing import Any, Callable, Optional, Sequence, Union +from typing import Any, Generic, Optional, Sequence, TypeVar, Union import importlib_metadata @@ -27,33 +26,13 @@ RoundRobinHostResolver, SingleHostResolver, ) +from arango.serializer import Deserializer, JsonDeserializer, JsonSerializer, Serializer +from arango.typings import DataTypes - -def default_serializer(x: Any) -> str: - """ - Default JSON serializer - - :param x: A JSON data type object to serialize - :type x: Any - :return: The object serialized as a JSON string - :rtype: str - """ - return dumps(x, separators=(",", ":")) - - -def default_deserializer(x: str) -> Any: - """ - Default JSON de-serializer - - :param x: A JSON string to deserialize - :type x: str - :return: The de-serialized JSON object - :rtype: Any - """ - return loads(x) +T = TypeVar("T") -class ArangoClient: +class ArangoClient(Generic[T]): """ArangoDB client. :param hosts: Host URL or list of URLs (coordinators in a cluster). @@ -104,8 +83,8 @@ def __init__( host_resolver: Union[str, HostResolver] = "fallback", resolver_max_tries: Optional[int] = None, http_client: Optional[HTTPClient] = None, - serializer: Callable[..., str] = default_serializer, - deserializer: Callable[[str], Any] = default_deserializer, + serializer: Serializer[T] = JsonSerializer(), + deserializer: Deserializer[DataTypes] = JsonDeserializer(), verify_override: Union[bool, str, None] = None, request_timeout: Union[int, float, None] = DEFAULT_REQUEST_TIMEOUT, request_compression: Optional[RequestCompression] = None, @@ -199,8 +178,7 @@ def db( auth_method: str = "basic", user_token: Optional[str] = None, superuser_token: Optional[str] = None, - verify_certificate: bool = True, - ) -> StandardDatabase: + ) -> StandardDatabase[T]: """Connect to an ArangoDB database and return the database API wrapper. :param name: Database name. @@ -228,14 +206,12 @@ def db( are ignored. This token is not refreshed automatically. Token expiry will not be checked. :type superuser_token: str - :param verify_certificate: Verify TLS certificates. - :type verify_certificate: bool :return: Standard database API wrapper. :rtype: arango.database.StandardDatabase :raise arango.exceptions.ServerConnectionError: If **verify** was set to True and the connection fails. """ - connection: Connection + connection: Connection[T] if superuser_token is not None: connection = JwtSuperuserConnection( diff --git a/arango/collection.py b/arango/collection.py index 820c5200..9f5278b3 100644 --- a/arango/collection.py +++ b/arango/collection.py @@ -1,7 +1,7 @@ __all__ = ["StandardCollection", "VertexCollection", "EdgeCollection"] from numbers import Number -from typing import List, Optional, Sequence, Tuple, Union +from typing import Any, Generic, List, Optional, Sequence, Tuple, TypeVar, Union from arango.api import ApiGroup from arango.connection import Connection @@ -56,8 +56,10 @@ is_none_or_str, ) +T = TypeVar("T", bound=dict[Any, Any]) -class Collection(ApiGroup): + +class Collection(ApiGroup, Generic[T]): """Base class for collection API wrappers. :param connection: HTTP connection. @@ -149,7 +151,7 @@ def _prep_from_body(self, document: Json, check_rev: bool) -> Tuple[str, Headers return doc_id, {"If-Match": document["_rev"]} def _prep_from_doc( - self, document: Union[str, Json], rev: Optional[str], check_rev: bool + self, document: Any, rev: Optional[str], check_rev: bool ) -> Tuple[str, Union[str, Json], Json]: """Prepare document ID, body and request headers. @@ -199,7 +201,7 @@ def _ensure_key_in_body(self, body: Json) -> Json: return body raise DocumentParseError('field "_key" or "_id" required') - def _ensure_key_from_id(self, body: Json) -> Json: + def _ensure_key_from_id(self, body: T) -> T: """Return the body with "_key" field if it has "_id" field. :param body: Document body. @@ -600,7 +602,7 @@ def response_handler(resp: Response) -> int: def has( self, - document: Union[str, Json], + document: Union[str, Json, T], rev: Optional[str] = None, check_rev: bool = True, allow_dirty_read: bool = False, @@ -1762,7 +1764,7 @@ def response_handler(resp: Response) -> bool: def insert_many( self, - documents: Sequence[Json], + documents: Sequence[T], return_new: bool = False, sync: Optional[bool] = None, silent: bool = False, @@ -1890,7 +1892,7 @@ def response_handler( def update_many( self, - documents: Sequence[Json], + documents: Sequence[T], check_rev: bool = True, merge: bool = True, keep_none: bool = True, @@ -2023,7 +2025,7 @@ def response_handler( def update_match( self, filters: Json, - body: Json, + body: Union[Json, T], limit: Optional[int] = None, keep_none: bool = True, sync: Optional[bool] = None, @@ -2103,7 +2105,7 @@ def response_handler(resp: Response) -> int: def replace_many( self, - documents: Sequence[Json], + documents: Sequence[T], check_rev: bool = True, return_new: bool = False, return_old: bool = False, @@ -2217,7 +2219,7 @@ def response_handler( def replace_match( self, filters: Json, - body: Json, + body: T, limit: Optional[int] = None, sync: Optional[bool] = None, allow_dirty_read: bool = False, @@ -2283,7 +2285,7 @@ def response_handler(resp: Response) -> int: def delete_many( self, - documents: Sequence[Json], + documents: Sequence[Union[T, Json]], return_old: bool = False, check_rev: bool = True, sync: Optional[bool] = None, @@ -2448,7 +2450,7 @@ def response_handler(resp: Response) -> int: def import_bulk( self, - documents: Sequence[Json], + documents: Sequence[T], halt_on_error: bool = True, details: bool = True, from_prefix: Optional[str] = None, @@ -2585,7 +2587,7 @@ def __getitem__(self, key: Union[str, Json]) -> Result[Optional[Json]]: def get( self, - document: Union[str, Json], + document: Union[str, Json, T], rev: Optional[str] = None, check_rev: bool = True, allow_dirty_read: bool = False, @@ -2635,7 +2637,7 @@ def response_handler(resp: Response) -> Optional[Json]: def insert( self, - document: Json, + document: T, return_new: bool = False, sync: Optional[bool] = None, silent: bool = False, @@ -2739,7 +2741,7 @@ def response_handler(resp: Response) -> Union[bool, Json]: def update( self, - document: Json, + document: Union[Json, T], check_rev: bool = True, merge: bool = True, keep_none: bool = True, @@ -2831,7 +2833,7 @@ def response_handler(resp: Response) -> Union[bool, Json]: def replace( self, - document: Json, + document: T, check_rev: bool = True, return_new: bool = False, return_old: bool = False, @@ -2916,7 +2918,7 @@ def response_handler(resp: Response) -> Union[bool, Json]: def delete( self, - document: Union[str, Json], + document: Union[str, Json, T], rev: Optional[str] = None, check_rev: bool = True, ignore_missing: bool = False, @@ -2995,7 +2997,7 @@ def response_handler(resp: Response) -> Union[bool, Json]: return self._execute(request, response_handler) -class VertexCollection(Collection): +class VertexCollection(Collection[T]): """Vertex collection API wrapper. :param connection: HTTP connection. @@ -3005,7 +3007,7 @@ class VertexCollection(Collection): """ def __init__( - self, connection: Connection, executor: ApiExecutor, graph: str, name: str + self, connection: Connection[T], executor: ApiExecutor, graph: str, name: str ) -> None: super().__init__(connection, executor, name) self._graph = graph @@ -3070,7 +3072,7 @@ def response_handler(resp: Response) -> Optional[Json]: def insert( self, - vertex: Json, + vertex: T, sync: Optional[bool] = None, silent: bool = False, return_new: bool = False, @@ -3119,7 +3121,7 @@ def response_handler(resp: Response) -> Union[bool, Json]: def update( self, - vertex: Json, + vertex: Union[Json, T], check_rev: bool = True, keep_none: bool = True, sync: Optional[bool] = None, @@ -3189,7 +3191,7 @@ def response_handler(resp: Response) -> Union[bool, Json]: def replace( self, - vertex: Json, + vertex: T, check_rev: bool = True, sync: Optional[bool] = None, silent: bool = False, @@ -3253,7 +3255,7 @@ def response_handler(resp: Response) -> Union[bool, Json]: def delete( self, - vertex: Union[str, Json], + vertex: Union[str, Json, T], rev: Optional[str] = None, check_rev: bool = True, ignore_missing: bool = False, @@ -3315,7 +3317,7 @@ def response_handler(resp: Response) -> Union[bool, Json]: return self._execute(request, response_handler) -class EdgeCollection(Collection): +class EdgeCollection(Collection[T]): """ArangoDB edge collection API wrapper. :param connection: HTTP connection. @@ -3388,7 +3390,7 @@ def response_handler(resp: Response) -> Optional[Json]: def insert( self, - edge: Json, + edge: T, sync: Optional[bool] = None, silent: bool = False, return_new: bool = False, @@ -3438,7 +3440,7 @@ def response_handler(resp: Response) -> Union[bool, Json]: def update( self, - edge: Json, + edge: Union[Json, T], check_rev: bool = True, keep_none: bool = True, sync: Optional[bool] = None, @@ -3508,7 +3510,7 @@ def response_handler(resp: Response) -> Union[bool, Json]: def replace( self, - edge: Json, + edge: T, check_rev: bool = True, sync: Optional[bool] = None, silent: bool = False, @@ -3573,7 +3575,7 @@ def response_handler(resp: Response) -> Union[bool, Json]: def delete( self, - edge: Union[str, Json], + edge: Union[str, Json, T], rev: Optional[str] = None, check_rev: bool = True, ignore_missing: bool = False, @@ -3635,8 +3637,8 @@ def response_handler(resp: Response) -> Union[bool, Json]: def link( self, - from_vertex: Union[str, Json], - to_vertex: Union[str, Json], + from_vertex: Union[str, Json, T], + to_vertex: Union[str, Json, T], data: Optional[Json] = None, sync: Optional[bool] = None, silent: bool = False, @@ -3672,7 +3674,7 @@ def link( def edges( self, - vertex: Union[str, Json], + vertex: Union[str, Json, T], direction: Optional[str] = None, allow_dirty_read: bool = False, ) -> Result[Json]: diff --git a/arango/connection.py b/arango/connection.py index 15ff3c70..c5b2c5f7 100644 --- a/arango/connection.py +++ b/arango/connection.py @@ -10,7 +10,8 @@ import sys import time from abc import abstractmethod -from typing import Any, Callable, Optional, Sequence, Set, Tuple, Union +from json import dumps, loads +from typing import Generic, List, Optional, Sequence, Set, Tuple, TypeVar, Union import jwt from jwt.exceptions import ExpiredSignatureError @@ -27,12 +28,15 @@ from arango.request import Request from arango.resolver import HostResolver from arango.response import Response -from arango.typings import Fields, Json +from arango.serializer import Deserializer, Serializer +from arango.typings import DataTypes, Fields, Json Connection = Union["BasicConnection", "JwtConnection", "JwtSuperuserConnection"] +T = TypeVar("T") -class BaseConnection: + +class BaseConnection(Generic[T]): """Base connection to a specific ArangoDB database.""" def __init__( @@ -42,8 +46,8 @@ def __init__( sessions: Sequence[Session], db_name: str, http_client: HTTPClient, - serializer: Callable[..., str], - deserializer: Callable[[str], Any], + serializer: Serializer[T], + deserializer: Deserializer[DataTypes], request_compression: Optional[RequestCompression] = None, response_compression: Optional[str] = None, ) -> None: @@ -76,7 +80,7 @@ def username(self) -> Optional[str]: """ return self._username - def serialize(self, obj: Any) -> str: + def serialize(self, obj: Union[T, Sequence[T]]) -> str: """Serialize the given object. :param obj: JSON object to serialize. @@ -86,7 +90,7 @@ def serialize(self, obj: Any) -> str: """ return self._serializer(obj) - def deserialize(self, string: str) -> Any: + def deserialize(self, string: str) -> DataTypes: """De-serialize the string and return the object. :param string: String to de-serialize. @@ -194,7 +198,7 @@ def prep_bulk_err_response(self, parent_response: Response, body: Json) -> Respo headers=parent_response.headers, status_code=parent_response.status_code, status_text=parent_response.status_text, - raw_body=self.serialize(body), + raw_body=dumps(body, separators=(",", ":")), ) resp.body = body resp.error_code = body["errorNum"] @@ -202,7 +206,9 @@ def prep_bulk_err_response(self, parent_response: Response, body: Json) -> Respo resp.is_success = False return resp - def normalize_data(self, data: Any) -> Union[str, MultipartEncoder, None]: + def normalize_data( + self, data: Union[str, MultipartEncoder, T, List[T], None] + ) -> Union[str, MultipartEncoder, None]: """Normalize request data. :param data: Request data. @@ -247,7 +253,7 @@ def send_request(self, request: Request) -> Response: # pragma: no cover raise NotImplementedError -class BasicConnection(BaseConnection): +class BasicConnection(BaseConnection[T]): """Connection to specific ArangoDB database using basic authentication. :param hosts: Host URL or list of URLs (coordinators in a cluster). @@ -279,8 +285,8 @@ def __init__( username: str, password: str, http_client: HTTPClient, - serializer: Callable[..., str], - deserializer: Callable[[str], Any], + serializer: Serializer[T], + deserializer: Deserializer[DataTypes], request_compression: Optional[RequestCompression] = None, response_compression: Optional[str] = None, ) -> None: @@ -310,7 +316,7 @@ def send_request(self, request: Request) -> Response: return self.process_request(host_index, request, auth=self._auth) -class JwtConnection(BaseConnection): +class JwtConnection(BaseConnection[T]): """Connection to specific ArangoDB database using JWT authentication. :param hosts: Host URL or list of URLs (coordinators in a cluster). @@ -340,8 +346,8 @@ def __init__( sessions: Sequence[Session], db_name: str, http_client: HTTPClient, - serializer: Callable[..., str], - deserializer: Callable[[str], Any], + serializer: Serializer[T], + deserializer: Deserializer[DataTypes], username: Optional[str] = None, password: Optional[str] = None, user_token: Optional[str] = None, @@ -420,6 +426,7 @@ def refresh_token(self) -> None: method="post", endpoint="/_open/auth", data={"username": self._username, "password": self._password}, + deserialize=False, ) host_index = self._host_resolver.get_host_index() @@ -429,6 +436,10 @@ def refresh_token(self) -> None: if not resp.is_success: raise JWTAuthError(resp, request) + resp.body = loads(resp.raw_body) + if "jwt" not in resp.body: + raise JWTRefreshError(f"JWT token not found in response body: {resp.body}") + self.set_token(resp.body["jwt"]) def set_token(self, token: str) -> None: @@ -461,7 +472,7 @@ def set_token(self, token: str) -> None: self._auth_header = f"bearer {self._token}" -class JwtSuperuserConnection(BaseConnection): +class JwtSuperuserConnection(BaseConnection[T]): """Connection to specific ArangoDB database using superuser JWT. :param hosts: Host URL or list of URLs (coordinators in a cluster). @@ -489,9 +500,9 @@ def __init__( sessions: Sequence[Session], db_name: str, http_client: HTTPClient, - serializer: Callable[..., str], - deserializer: Callable[[str], Any], superuser_token: str, + serializer: Serializer[T], + deserializer: Deserializer[DataTypes], request_compression: Optional[RequestCompression] = None, response_compression: Optional[str] = None, ) -> None: diff --git a/arango/cursor.py b/arango/cursor.py index 83da8881..2166e4c6 100644 --- a/arango/cursor.py +++ b/arango/cursor.py @@ -50,7 +50,7 @@ class Cursor: def __init__( self, - connection: BaseConnection, + connection: BaseConnection[Any], init_data: Json, cursor_type: str = "cursor", allow_retry: bool = False, diff --git a/arango/database.py b/arango/database.py index 7348c256..7cd24744 100644 --- a/arango/database.py +++ b/arango/database.py @@ -8,7 +8,7 @@ from datetime import datetime from numbers import Number -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, Dict, Generic, List, Optional, Sequence, TypeVar, Union from warnings import warn from arango.api import ApiGroup @@ -117,8 +117,10 @@ from arango.utils import get_col_name from arango.wal import WAL +T = TypeVar("T") -class Database(ApiGroup): + +class Database(ApiGroup, Generic[T]): """Base class for Database API wrappers.""" def __getitem__(self, name: str) -> StandardCollection: @@ -152,7 +154,7 @@ def name(self) -> str: return self.db_name @property - def aql(self) -> AQL: + def aql(self) -> AQL[T]: """Return AQL (ArangoDB Query Language) API wrapper. :return: AQL API wrapper. @@ -1335,7 +1337,7 @@ def response_handler(resp: Response) -> bool: # Collection Management # ######################### - def collection(self, name: str) -> StandardCollection: + def collection(self, name: str) -> StandardCollection[T]: """Return the standard collection API wrapper. :param name: Collection name. @@ -1408,7 +1410,7 @@ def create_collection( write_concern: Optional[int] = None, schema: Optional[Json] = None, computedValues: Optional[Jsons] = None, - ) -> Result[StandardCollection]: + ) -> Result[StandardCollection[T]]: """Create a new collection. :param name: Collection name. @@ -1793,7 +1795,10 @@ def response_handler(resp: Response) -> bool: ####################### def has_document( - self, document: Json, rev: Optional[str] = None, check_rev: bool = True + self, + document: Union[str, Json, T], + rev: Optional[str] = None, + check_rev: bool = True, ) -> Result[bool]: """Check if a document exists. @@ -1815,7 +1820,10 @@ def has_document( ) def document( - self, document: Json, rev: Optional[str] = None, check_rev: bool = True + self, + document: Union[str, Json, T], + rev: Optional[str] = None, + check_rev: bool = True, ) -> Result[Optional[Json]]: """Return a document. @@ -1839,7 +1847,7 @@ def document( def insert_document( self, collection: str, - document: Json, + document: T, return_new: bool = False, sync: Optional[bool] = None, silent: bool = False, @@ -1903,7 +1911,7 @@ def insert_document( def update_document( self, - document: Json, + document: Union[Json, T], check_rev: bool = True, merge: bool = True, keep_none: bool = True, @@ -1954,7 +1962,7 @@ def update_document( def replace_document( self, - document: Json, + document: T, check_rev: bool = True, return_new: bool = False, return_old: bool = False, @@ -1996,7 +2004,7 @@ def replace_document( def delete_document( self, - document: Union[str, Json], + document: Union[str, Json, T], rev: Optional[str] = None, check_rev: bool = True, ignore_missing: bool = False, @@ -2948,10 +2956,10 @@ def response_handler(resp: Response) -> Json: return self._execute(request, response_handler) -class StandardDatabase(Database): +class StandardDatabase(Database[T]): """Standard database API wrapper.""" - def __init__(self, connection: Connection) -> None: + def __init__(self, connection: Connection[T]) -> None: super().__init__(connection=connection, executor=DefaultApiExecutor(connection)) def __repr__(self) -> str: @@ -3071,7 +3079,7 @@ def begin_controlled_execution( return OverloadControlDatabase(self._conn, max_queue_time_seconds) -class AsyncDatabase(Database): +class AsyncDatabase(Database[T]): """Database API wrapper tailored specifically for async execution. See :func:`arango.database.StandardDatabase.begin_async_execution`. @@ -3094,7 +3102,7 @@ def __repr__(self) -> str: return f"" -class BatchDatabase(Database): +class BatchDatabase(Database[T]): """Database API wrapper tailored specifically for batch execution. .. note:: @@ -3163,7 +3171,7 @@ def commit(self) -> Optional[Sequence[BatchJob[Any]]]: return self._executor.commit() -class TransactionDatabase(Database): +class TransactionDatabase(Database[T]): """Database API wrapper tailored specifically for transactions. See :func:`arango.database.StandardDatabase.begin_transaction`. @@ -3261,7 +3269,7 @@ def abort_transaction(self) -> bool: return self._executor.abort() -class OverloadControlDatabase(Database): +class OverloadControlDatabase(Database[T]): """Database API wrapper tailored to gracefully handle server overload scenarios. See :func:`arango.database.StandardDatabase.begin_controlled_execution`. diff --git a/arango/serializer.py b/arango/serializer.py index ea72ccc7..6ee07e7d 100644 --- a/arango/serializer.py +++ b/arango/serializer.py @@ -6,42 +6,45 @@ ] from json import dumps, loads -from typing import Any, Generic, TypeVar +from typing import Generic, Sequence, TypeVar, Union + +from arango.typings import DataTypes, Json T = TypeVar("T") class Serializer(Generic[T]): """ - Serializer interface + Serializer interface. + For the use of bulk operations, it must also support List[T]. """ - def __call__(self, data: T) -> str: + def __call__(self, obj: Union[T, Sequence[T]]) -> str: raise NotImplementedError -class Deserializer: +class Deserializer(Generic[T]): """ De-serializer interface """ - def __call__(self, data: str) -> Any: + def __call__(self, s: str) -> T: raise NotImplementedError -class JsonSerializer(Serializer[Any]): +class JsonSerializer(Serializer[Json]): """ Default JSON serializer """ - def __call__(self, data: Any) -> str: - return dumps(data, separators=(",", ":")) + def __call__(self, obj: Union[Json, Sequence[Json]]) -> str: + return dumps(obj, separators=(",", ":")) -class JsonDeserializer(Deserializer): +class JsonDeserializer(Deserializer[DataTypes]): """ Default JSON de-serializer """ - def __call__(self, data: str) -> Any: - return loads(data) + def __call__(self, s: str) -> DataTypes: + return loads(s) diff --git a/arango/utils.py b/arango/utils.py index 359b1e37..daa9df1a 100644 --- a/arango/utils.py +++ b/arango/utils.py @@ -96,12 +96,12 @@ def is_none_or_bool(obj: Any) -> bool: return obj is None or isinstance(obj, bool) -def get_batches(elements: Sequence[Json], batch_size: int) -> Iterator[Sequence[Json]]: +def get_batches(elements: Sequence[Any], batch_size: int) -> Iterator[Sequence[Any]]: """Generator to split a list in batches of (maximum) **batch_size** elements each. :param elements: The list of elements. - :type elements: Sequence[Json] + :type elements: Sequence[Any] :param batch_size: Max number of elements per batch. :type batch_size: int """