Skip to content

Commit

Permalink
implement sorting
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-oleshkevich committed Apr 16, 2024
1 parent df07e53 commit b4f376f
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 8 deletions.
60 changes: 52 additions & 8 deletions ohmyadmin/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import abc
import datetime
import enum
import functools
import math
import operator
import re
import typing

Expand Down Expand Up @@ -38,6 +40,10 @@ class Query(abc.ABC, typing.Generic[T]): # pragma: no cover
def filter(self, **filters: typing.Any) -> typing.Self:
raise NotImplementedError()

@abc.abstractmethod
def order_by(self, *fields: str) -> typing.Self:
raise NotImplementedError()

@abc.abstractmethod
async def one(self) -> T | None:
raise NotImplementedError()
Expand Down Expand Up @@ -99,10 +105,11 @@ class MemoryQuery(Query[T]):

def __init__(self, models: typing.Sequence[T]) -> None:
self.models = models
self.filters: list[QueryFilter[T]] = []
self._filters: list[QueryFilter[T]] = []
self._order_by: list[str] = []

def filter(self, *ops: QueryFilter[T], **filters: typing.Any) -> typing.Self:
self.filters.extend(ops)
self._filters.extend(ops)
for field, value in filters.items():
*field_parts, op = field.split("__")
case_sensitive = True
Expand All @@ -119,22 +126,26 @@ def filter(self, *ops: QueryFilter[T], **filters: typing.Any) -> typing.Self:
value=value,
case_sensitive=case_sensitive,
)
self.filters.append(filter_)
self._filters.append(filter_)

if op in NumericOperation:
filter_ = MemoryNumericFilter(field=field_name, op=NumericOperation(op), value=value)
self.filters.append(filter_)
self._filters.append(filter_)

if op in DateOperation:
filter_ = MemoryDateFilter(field=field_name, op=DateOperation(op), value=value)
self.filters.append(filter_)
self._filters.append(filter_)

if op in IsOperation:
filter_ = MemoryIsFilter(field=field_name, op=IsOperation(op), value=value)
self.filters.append(filter_)
self._filters.append(filter_)

return self

def order_by(self, *fields: str) -> typing.Self:
self._order_by = list(fields)
return self

async def one(self) -> T | None:
total = len(self.models)
if total > 1:
Expand All @@ -150,9 +161,42 @@ async def paginate(self, page: int, page_size: int) -> Paginator[T]:

def _get_filtered_query(self) -> Query[T]:
query: Query[T] = self
for filter_ in self.filters:
for filter_ in self._filters:
query = filter_.apply(query)
return query

models = typing.cast(MemoryQuery[T], query).models
return MemoryQuery(models=multikeysort(models, self._order_by))


def multikeysort(items: typing.Iterable[T], columns: list[str]) -> list[T]:
"""
Thanks to to https://stackoverflow.com/questions/4233476/sort-a-list-by-multiple-attributes
"""

def get_comparers() -> typing.Sequence[tuple[typing.Callable[[T], int], int]]:
comparers = []

for col in columns:
col = col.strip()
if col.startswith("-"): # If descending, strip '-' and create a comparer with reverse order
key = operator.attrgetter(col[1:])
order = -1
else: # If ascending, use the column directly
key = operator.attrgetter(col)
order = 1

comparers.append((key, order))
return comparers

def custom_compare(left: T, right: T) -> int:
"""Custom comparison function to handle multiple keys."""
for fn, reverse in get_comparers():
result = (fn(left) > fn(right)) - (fn(left) < fn(right))
if result != 0:
return result * reverse
return 0

return sorted(items, key=functools.cmp_to_key(custom_compare))


class MemorySource(DataSource[T]):
Expand Down
31 changes: 31 additions & 0 deletions tests/test_datasource.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
import random

import pytest
from starlette.requests import Request
Expand Down Expand Up @@ -820,3 +821,33 @@ async def test_one_raises_for_many(self, http_request: Request) -> None:
with pytest.raises(MultipleObjectsError):
datasource = MemorySource([root_user, user_user])
assert await datasource.query(http_request).one() is None


class TestSorting:
async def test_order_by(self, http_request: Request) -> None:
datasource = MemorySource([user_user, root_user])
assert await datasource.query(http_request).order_by("id").all() == [root_user, user_user]
assert await datasource.query(http_request).order_by("-id").all() == [user_user, root_user]

async def test_order_by_multifield(self, http_request: Request) -> None:
user1 = User(id=1, child=1, parent=1)
user2 = User(id=2, child=1, parent=2)
user3 = User(id=3, child=2, parent=2)
user4 = User(id=4, child=2, parent=3)

dataset = [user1, user2, user3, user4]
random.shuffle(dataset)

datasource = MemorySource(dataset)
assert await datasource.query(http_request).order_by("child", "parent").all() == [
user1,
user2,
user3,
user4,
]
assert await datasource.query(http_request).order_by("child", "-parent").all() == [
user2,
user1,
user4,
user3,
]

0 comments on commit b4f376f

Please sign in to comment.