diff --git a/arango/collection.py b/arango/collection.py index 4fd848fa..3aea6f81 100644 --- a/arango/collection.py +++ b/arango/collection.py @@ -42,7 +42,7 @@ from arango.response import Response from arango.result import Result from arango.typings import Fields, Headers, Json, Params -from arango.utils import get_doc_id, is_none_or_int, is_none_or_str +from arango.utils import get_batches, get_doc_id, is_none_or_int, is_none_or_str class Collection(ApiGroup): @@ -1934,7 +1934,8 @@ def import_bulk( overwrite: Optional[bool] = None, on_duplicate: Optional[str] = None, sync: Optional[bool] = None, - ) -> Result[Json]: + batch_size: Optional[int] = None, + ) -> Union[Result[Json], List[Result[Json]]]: """Insert multiple documents into the collection. .. note:: @@ -1984,8 +1985,17 @@ def import_bulk( :type on_duplicate: str :param sync: Block until operation is synchronized to disk. :type sync: bool | None + :param batch_size: Split up **documents** into batches of max length + **batch_size** and import them in a loop on the client side. If + **batch_size** is specified, the return type of this method + changes from a result object to a list of result objects. + IMPORTANT NOTE: this parameter may go through breaking changes + in the future where the return type may not be a list of result + objects anymore. Use it at your own risk, and avoid + depending on the return value if possible. + :type batch_size: int :return: Result of the bulk import. - :rtype: dict + :rtype: dict | list[dict] :raise arango.exceptions.DocumentInsertError: If import fails. """ documents = [self._ensure_key_from_id(doc) for doc in documents] @@ -2006,21 +2016,35 @@ def import_bulk( if sync is not None: params["waitForSync"] = sync - request = Request( - method="post", - endpoint="/_api/import", - data=documents, - params=params, - write=self.name, - ) - def response_handler(resp: Response) -> Json: if resp.is_success: result: Json = resp.body return result raise DocumentInsertError(resp, request) - return self._execute(request, response_handler) + if batch_size is None: + request = Request( + method="post", + endpoint="/_api/import", + data=documents, + params=params, + write=self.name, + ) + + return self._execute(request, response_handler) + else: + results = [] + for batch in get_batches(documents, batch_size): + request = Request( + method="post", + endpoint="/_api/import", + data=batch, + params=params, + write=self.name, + ) + results.append(self._execute(request, response_handler)) + + return results class StandardCollection(Collection): diff --git a/arango/database.py b/arango/database.py index ea9ac058..5e9dcfa8 100644 --- a/arango/database.py +++ b/arango/database.py @@ -1225,11 +1225,13 @@ def create_graph( .. code-block:: python - { - 'edge_collection': 'teach', - 'from_vertex_collections': ['teachers'], - 'to_vertex_collections': ['lectures'] - } + [ + { + 'edge_collection': 'teach', + 'from_vertex_collections': ['teachers'], + 'to_vertex_collections': ['lectures'] + } + ] """ data: Json = {"name": name, "options": dict()} if edge_definitions is not None: diff --git a/arango/utils.py b/arango/utils.py index 42d7fff3..27d459ba 100644 --- a/arango/utils.py +++ b/arango/utils.py @@ -8,7 +8,7 @@ import logging from contextlib import contextmanager -from typing import Any, Iterator, Union +from typing import Any, Iterator, Sequence, Union from arango.exceptions import DocumentParseError from arango.typings import Json @@ -82,3 +82,16 @@ def is_none_or_str(obj: Any) -> bool: :rtype: bool """ return obj is None or isinstance(obj, str) + + +def get_batches(elements: Sequence[Json], batch_size: int) -> Iterator[Sequence[Json]]: + """Generator to split a list in batches + of (maximum) **batch_size** elements each. + + :param elements: The list of elements. + :type elements: Sequence[Json] + :param batch_size: Max number of elements per batch. + :type batch_size: int + """ + for index in range(0, len(elements), batch_size): + yield elements[index : index + batch_size] diff --git a/tests/test_document.py b/tests/test_document.py index 95446da1..02c47bc4 100644 --- a/tests/test_document.py +++ b/tests/test_document.py @@ -1832,6 +1832,17 @@ def test_document_import_bulk(col, bad_col, docs): assert col[doc_key]["loc"] == doc["loc"] empty_collection(col) + # Test import bulk with batch_size + results = col.import_bulk(docs, batch_size=len(docs) // 2) + assert type(results) is list + assert len(results) == 2 + empty_collection(col) + + result = col.import_bulk(docs, batch_size=len(docs) * 2) + assert type(result) is list + assert len(result) == 1 + empty_collection(col) + # Test import bulk on_duplicate actions doc = docs[0] doc_key = doc["_key"]