Commit 917f699

Add batch_size parameter in import_bulk method (#207)

aMahanna authored Jun 24, 2022
1 parent 21c9e5d commit 917f699
Showing 4 changed files with 68 additions and 18 deletions.
48 changes: 36 additions & 12 deletions arango/collection.py
@@ -42,7 +42,7 @@
 from arango.response import Response
 from arango.result import Result
 from arango.typings import Fields, Headers, Json, Params
-from arango.utils import get_doc_id, is_none_or_int, is_none_or_str
+from arango.utils import get_batches, get_doc_id, is_none_or_int, is_none_or_str
 
 
 class Collection(ApiGroup):
@@ -1934,7 +1934,8 @@ def import_bulk(
         overwrite: Optional[bool] = None,
         on_duplicate: Optional[str] = None,
         sync: Optional[bool] = None,
-    ) -> Result[Json]:
+        batch_size: Optional[int] = None,
+    ) -> Union[Result[Json], List[Result[Json]]]:
         """Insert multiple documents into the collection.
 
         .. note::
@@ -1984,8 +1985,17 @@
         :type on_duplicate: str
         :param sync: Block until operation is synchronized to disk.
         :type sync: bool | None
+        :param batch_size: Split up **documents** into batches of max length
+            **batch_size** and import them in a loop on the client side. If
+            **batch_size** is specified, the return type of this method
+            changes from a result object to a list of result objects.
+            IMPORTANT NOTE: this parameter may go through breaking changes
+            in the future where the return type may not be a list of result
+            objects anymore. Use it at your own risk, and avoid
+            depending on the return value if possible.
+        :type batch_size: int | None
         :return: Result of the bulk import.
-        :rtype: dict
+        :rtype: dict | list[dict]
         :raise arango.exceptions.DocumentInsertError: If import fails.
         """
         documents = [self._ensure_key_from_id(doc) for doc in documents]
@@ -2006,21 +2016,35 @@
         if sync is not None:
             params["waitForSync"] = sync
 
-        request = Request(
-            method="post",
-            endpoint="/_api/import",
-            data=documents,
-            params=params,
-            write=self.name,
-        )
-
         def response_handler(resp: Response) -> Json:
             if resp.is_success:
                 result: Json = resp.body
                 return result
             raise DocumentInsertError(resp, request)
 
-        return self._execute(request, response_handler)
+        if batch_size is None:
+            request = Request(
+                method="post",
+                endpoint="/_api/import",
+                data=documents,
+                params=params,
+                write=self.name,
+            )
+
+            return self._execute(request, response_handler)
+        else:
+            results = []
+            for batch in get_batches(documents, batch_size):
+                request = Request(
+                    method="post",
+                    endpoint="/_api/import",
+                    data=batch,
+                    params=params,
+                    write=self.name,
+                )
+                results.append(self._execute(request, response_handler))
+
+            return results
 
 
 class StandardCollection(Collection):
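To illustrate the new parameter, here is a minimal usage sketch (not part of the commit). The connection details and the "students" collection name are assumptions for illustration; it presumes a running ArangoDB instance with that collection already created.

    from arango import ArangoClient

    # Connection details below are illustrative assumptions.
    client = ArangoClient(hosts="http://localhost:8529")
    db = client.db("test", username="root", password="passwd")
    col = db.collection("students")  # hypothetical, pre-existing collection

    docs = [{"_key": str(i), "value": i} for i in range(10)]

    # Default behavior: one server request, one result dict.
    result = col.import_bulk(docs)
    assert result["created"] == len(docs)

    col.truncate()

    # With batch_size=3, the 10 documents are imported in ceil(10 / 3) = 4
    # client-side requests, and a list of result dicts is returned.
    results = col.import_bulk(docs, batch_size=3)
    assert isinstance(results, list) and len(results) == 4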
12 changes: 7 additions & 5 deletions arango/database.py
@@ -1225,11 +1225,13 @@ def create_graph(
         .. code-block:: python
 
-            {
-                'edge_collection': 'teach',
-                'from_vertex_collections': ['teachers'],
-                'to_vertex_collections': ['lectures']
-            }
+            [
+                {
+                    'edge_collection': 'teach',
+                    'from_vertex_collections': ['teachers'],
+                    'to_vertex_collections': ['lectures']
+                }
+            ]
 
         """
         data: Json = {"name": name, "options": dict()}
         if edge_definitions is not None:
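For context, the corrected docstring example corresponds to a create_graph call like the sketch below. The graph and collection names, plus the connection setup, are assumptions for illustration.

    from arango import ArangoClient

    # Connection details are illustrative assumptions.
    db = ArangoClient(hosts="http://localhost:8529").db(
        "test", username="root", password="passwd"
    )

    # edge_definitions takes a list of edge definition dicts, matching the
    # corrected docstring example above.
    graph = db.create_graph(
        "school",  # hypothetical graph name
        edge_definitions=[
            {
                "edge_collection": "teach",
                "from_vertex_collections": ["teachers"],
                "to_vertex_collections": ["lectures"],
            }
        ],
    )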
15 changes: 14 additions & 1 deletion arango/utils.py
@@ -8,7 +8,7 @@
 
 import logging
 from contextlib import contextmanager
-from typing import Any, Iterator, Union
+from typing import Any, Iterator, Sequence, Union
 
 from arango.exceptions import DocumentParseError
 from arango.typings import Json
@@ -82,3 +82,16 @@ def is_none_or_str(obj: Any) -> bool:
     :rtype: bool
     """
     return obj is None or isinstance(obj, str)
+
+
+def get_batches(elements: Sequence[Json], batch_size: int) -> Iterator[Sequence[Json]]:
+    """Generator to split a list into batches
+    of (maximum) **batch_size** elements each.
+
+    :param elements: The list of elements.
+    :type elements: Sequence[Json]
+    :param batch_size: Max number of elements per batch.
+    :type batch_size: int
+    """
+    for index in range(0, len(elements), batch_size):
+        yield elements[index : index + batch_size]
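As a quick sanity check of the new helper, the sketch below (an illustration, not part of the commit) shows how get_batches partitions a list: order is preserved and the final batch may be shorter than batch_size.

    from arango.utils import get_batches

    docs = [{"_key": str(i)} for i in range(5)]

    # 5 elements with batch_size=2 -> batches of sizes 2, 2, 1.
    assert [len(batch) for batch in get_batches(docs, batch_size=2)] == [2, 2, 1]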
11 changes: 11 additions & 0 deletions tests/test_document.py
@@ -1832,6 +1832,17 @@ def test_document_import_bulk(col, bad_col, docs):
         assert col[doc_key]["loc"] == doc["loc"]
     empty_collection(col)
 
+    # Test import bulk with batch_size
+    results = col.import_bulk(docs, batch_size=len(docs) // 2)
+    assert type(results) is list
+    assert len(results) == 2
+    empty_collection(col)
+
+    result = col.import_bulk(docs, batch_size=len(docs) * 2)
+    assert type(result) is list
+    assert len(result) == 1
+    empty_collection(col)
+
     # Test import bulk on_duplicate actions
     doc = docs[0]
     doc_key = doc["_key"]