Skip to content

Commit

Permalink
add support for sort in collection.find function
Browse files Browse the repository at this point in the history
  • Loading branch information
Deniz Alpaslan committed Jan 26, 2025
1 parent a7ff90d commit 32cdc62
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,6 @@ arango/version.py

# test results
*_results.txt

# devcontainers
.devcontainer
7 changes: 6 additions & 1 deletion arango/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,13 @@
from arango.typings import Fields, Headers, Json, Jsons, Params
from arango.utils import (
build_filter_conditions,
build_sort_expression,
get_batches,
get_doc_id,
is_none_or_bool,
is_none_or_int,
is_none_or_str,
validate_sort_parameters,
)


Expand Down Expand Up @@ -753,6 +755,7 @@ def find(
skip: Optional[int] = None,
limit: Optional[int] = None,
allow_dirty_read: bool = False,
sort: Sequence[Json] = [],
) -> Result[Cursor]:
"""Return all documents that match the given filters.
Expand All @@ -771,16 +774,18 @@ def find(
assert isinstance(filters, dict), "filters must be a dict"
assert is_none_or_int(skip), "skip must be a non-negative int"
assert is_none_or_int(limit), "limit must be a non-negative int"
if sort:
validate_sort_parameters(sort)

skip_val = skip if skip is not None else 0
limit_val = limit if limit is not None else "null"
query = f"""
FOR doc IN @@collection
{build_filter_conditions(filters)}
LIMIT {skip_val}, {limit_val}
{build_sort_expression(sort)}
RETURN doc
"""

bind_vars = {"@collection": self.name}

request = Request(
Expand Down
39 changes: 39 additions & 0 deletions arango/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,42 @@ def build_filter_conditions(filters: Json) -> str:
conditions.append(f"doc.{field} == {json.dumps(v)}")

return "FILTER " + " AND ".join(conditions)


def validate_sort_parameters(sort: Sequence[Json]) -> bool:
"""Validate sort parameters for an AQL query.
:param sort: Document sort parameters.
:type sort: Sequence[Json]
:return: Validation success.
:rtype: bool
:raise arango.exceptions.DocumentGetError: If sort parameters are invalid.
"""
assert isinstance(sort, Sequence)
for param in sort:
if "sort_by" not in param or "sort_order" not in param:
raise DocumentParseError(
"Each sort parameter must have 'sort_by' and 'sort_order'."
)
if param["sort_order"].upper() not in ["ASC", "DESC"]:
raise DocumentParseError("'sort_order' must be either 'ASC' or 'DESC'")
return True


def build_sort_expression(sort: Sequence[Json]) -> str:
"""Build a sort condition for an AQL query.
:param sort: Document sort parameters.
:type sort: Sequence[Json]
:return: The complete AQL sort condition.
:rtype: str
"""
if not sort:
return ""

sort_chunks = []
for sort_param in sort:
chunk = f"doc.{sort_param['sort_by']} {sort_param['sort_order']}"
sort_chunks.append(chunk)

return "SORT " + ", ".join(sort_chunks)
6 changes: 6 additions & 0 deletions docs/document.rst
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,12 @@ Standard documents are managed via collection API wrapper:
assert student['GPA'] == 3.6
assert student['last'] == 'Kim'

# Retrieve one or more matching documents, sorted by a field.
for student in students.find({'first': 'John'}, sort=[{'sort_by': 'GPA', 'sort_order': 'DESC'}]):
assert student['_key'] == 'john'
assert student['GPA'] == 3.6
assert student['last'] == 'Kim'

# Retrieve a document by key.
students.get('john')

Expand Down
20 changes: 20 additions & 0 deletions tests/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -1162,6 +1162,26 @@ def test_document_find(col, bad_col, docs):
# Set up test documents
col.import_bulk(docs)

# Test find with sort expression (single field)
found = list(col.find({}, sort=[{"sort_by": "text", "sort_order": "ASC"}]))
assert len(found) == 6
assert found[0]["text"] == "bar"
assert found[-1]["text"] == "foo"

# Test find with sort expression (multiple fields)
found = list(
col.find(
{},
sort=[
{"sort_by": "text", "sort_order": "ASC"},
{"sort_by": "val", "sort_order": "DESC"},
],
)
)
assert len(found) == 6
assert found[0]["val"] == 6
assert found[-1]["val"] == 1

# Test find (single match) with default options
found = list(col.find({"val": 2}))
assert len(found) == 1
Expand Down

0 comments on commit 32cdc62

Please sign in to comment.