Skip to content

Commit

Permalink
Add Support for QuerySet slicing operator
Browse files Browse the repository at this point in the history
  • Loading branch information
yuvalbenarie committed Apr 27, 2024
1 parent b72c175 commit 08a6cb6
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Added
- Add __eq__ method to Q to more easily test dynamically-built queries (#1506)
- Added PlainToTsQuery function for postgres (#1347)
- Allow field's default keyword to be async function (#1498)
- Add support for queryset slicing. (#1341)

Fixed
^^^^^
Expand Down
1 change: 1 addition & 0 deletions CONTRIBUTORS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ Contributors
* Paul Serov ``@thakryptex``
* Stanislav Zmiev ``@Ovsyanka83``
* Waket Zheng ``@waketzheng``
* Yuval Ben-Arie ``@yuvalbenarie``

Special Thanks
==============
Expand Down
52 changes: 52 additions & 0 deletions tests/test_queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,58 @@ async def test_offset_negative(self):
with self.assertRaisesRegex(ParamsError, "Offset should be non-negative number"):
await IntFields.all().offset(-10)

async def test_slicing_start_and_stop(self) -> None:
sliced_queryset = IntFields.all().order_by("intnum")[1:5]
manually_sliced_queryset = IntFields.all().order_by("intnum").offset(1).limit(4)
self.assertSequenceEqual(await sliced_queryset, await manually_sliced_queryset)

async def test_slicing_only_limit(self) -> None:
sliced_queryset = IntFields.all().order_by("intnum")[:5]
manually_sliced_queryset = IntFields.all().order_by("intnum").limit(5)
self.assertSequenceEqual(await sliced_queryset, await manually_sliced_queryset)

async def test_slicing_only_offset(self) -> None:
sliced_queryset = IntFields.all().order_by("intnum")[5:]
manually_sliced_queryset = IntFields.all().order_by("intnum").offset(5)
self.assertSequenceEqual(await sliced_queryset, await manually_sliced_queryset)

async def test_slicing_count(self) -> None:
queryset = IntFields.all().order_by("intnum")[1:5]
self.assertEqual(await queryset.count(), 4)

def test_slicing_negative_values(self) -> None:
with self.assertRaisesRegex(
expected_exception=ParamsError,
expected_regex="Slice start should be non-negative number or None.",
):
_ = IntFields.all()[-1:]

with self.assertRaisesRegex(
expected_exception=ParamsError,
expected_regex="Slice stop should be non-negative number greater that slice start, "
"or None.",
):
_ = IntFields.all()[:-1]

def test_slicing_stop_before_start(self) -> None:
with self.assertRaisesRegex(
expected_exception=ParamsError,
expected_regex="Slice stop should be non-negative number greater that slice start, "
"or None.",
):
_ = IntFields.all()[2:1]

async def test_slicing_steps(self) -> None:
sliced_queryset = IntFields.all().order_by("intnum")[::1]
manually_sliced_queryset = IntFields.all().order_by("intnum")
self.assertSequenceEqual(await sliced_queryset, await manually_sliced_queryset)

with self.assertRaisesRegex(
expected_exception=ParamsError,
expected_regex="Slice steps should be 1 or None.",
):
_ = IntFields.all()[::2]

async def test_join_count(self):
tour = await Tournament.create(name="moo")
await MinRelation.create(tournament=tour)
Expand Down
33 changes: 33 additions & 0 deletions tortoise/queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,39 @@ def offset(self, offset: int) -> "QuerySet[MODEL]":
queryset._limit = 1000000
return queryset

def __getitem__(self, key: slice) -> "QuerySet[MODEL]":
"""
Query offset and limit for Queryset.
:raises ParamsError: QuerySet indices must be slices.
:raises ParamsError: Slice steps should be 1 or None.
:raises ParamsError: Slice start should be non-negative number or None.
:raises ParamsError: Slice stop should be non-negative number greater that slice start,
or None.
"""
if not isinstance(key, slice):
raise ParamsError("QuerySet indices must be slices.")

if not (key.step is None or (isinstance(key.step, int) and key.step == 1)):
raise ParamsError("Slice steps should be 1 or None.")

start = key.start if key.start is not None else 0

if not isinstance(start, int) or start < 0:
raise ParamsError("Slice start should be non-negative number or None.")
if key.stop is not None and (not isinstance(key.stop, int) or key.stop <= start):
raise ParamsError(
"Slice stop should be non-negative number greater that slice start, or None.",
)

queryset = self.offset(start)
if key.stop:
queryset = queryset.limit(key.stop - start)
return queryset

def distinct(self) -> "QuerySet[MODEL]":
"""
Make QuerySet distinct.
Expand Down

0 comments on commit 08a6cb6

Please sign in to comment.