Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FireO in FastAPI-filters #190

Open
ADR-007 opened this issue Mar 6, 2023 · 0 comments
Open

FireO in FastAPI-filters #190

ADR-007 opened this issue Mar 6, 2023 · 0 comments
Labels
enhancement New feature or request

Comments

@ADR-007
Copy link
Collaborator

ADR-007 commented Mar 6, 2023

Hi!

I implemented FireO support for FastAPI-filters, but I'm not sure when I'll be able to open a PR against the library.
Please let me know if you would like to use it, and I'll try to find time for it. :)

Code snippet:

from copy import deepcopy
from types import UnionType
from typing import Any, Dict, Generic, Literal, Optional, Tuple, Type, TypeVar, Union

import fastapi_filter.base.filter as filter_lib
from fastapi import Query, params
from fastapi_filter.base.filter import BaseFilterModel
from fireo.managers.managers import Manager
from fireo.queries.filter_query import FilterQuery
from fireo.queries.query_set import QuerySet
from pydantic import root_validator, validator
from pydantic.fields import SHAPE_LIST, FieldInfo, ModelField, Undefined

# Maps the operator suffix of a filter field name (the part after "__"; the
# empty string means "no suffix", i.e. plain equality) to a callable that
# applies the matching filter to a query and returns the resulting query.
# Every callable has the signature ``(query, field_name, value)``.
_orm_operator_filter = {
    "": lambda query, field_name, value: query.filter(field_name, "==", value),
    "not_eq": lambda query, field_name, value: query.filter(field_name, "!=", value),
    "gt": lambda query, field_name, value: query.filter(field_name, ">", value),
    "gte": lambda query, field_name, value: query.filter(field_name, ">=", value),
    "in": lambda query, field_name, value: query.filter(field_name, "in", value),
    # `isnull=True` matches documents whose field equals None, anything else
    # matches documents whose field differs from None.
    "isnull": lambda query, field_name, value: query.filter(field_name, ("==" if value is True else "!="), None),
    # Fixed: this used to pass "<>", which is not a less-than operator.
    "lt": lambda query, field_name, value: query.filter(field_name, "<", value),
    "lte": lambda query, field_name, value: query.filter(field_name, "<=", value),
    "not_in": lambda query, field_name, value: query.filter(field_name, "not_in", value),
    "contains": lambda query, field_name, value: query.filter(field_name, "array-contains", value),
    "overlap": lambda query, field_name, value: query.filter(field_name, "array-contains-any", value),
    # Prefix match expressed as a half-open range: [value, value + U+FFFD).
    "startswith": lambda query, field_name, value: (
        query.filter(field_name, ">=", value).filter(field_name, "<", value + "\ufffd")
    ),
}
# Operators treated as "inequality" filters by the sort validation: every
# operator above except plain equality, "in", "not_in" and "isnull".
_orm_op_conflicts_with_sorting = set(_orm_operator_filter) - {"", "in", "not_in", "isnull"}


class FireoFilter(BaseFilterModel):
    """Base filter for Firestore related filters.

    Filter fields are named ``<field>`` (equality) or ``<field>__<operator>``,
    where ``<operator>`` is one of the keys of ``_orm_operator_filter``.

    Example:
        ```python
        class MyModel(Model):
            name = TextField(required=True)
            count = NumberField(int_only=True)
            created_at = DatetimeField()

        class MyModelFilter(FireoFilter):
            id: Optional[int]
            id__in: Optional[str]
            count: Optional[int]
            count__lte: Optional[int]
            created_at__gt: Optional[datetime]
            name__not_eq: Optional[str]
            name__not_in: Optional[list[str]]
        ```
    """

    @validator("*", pre=True)
    def split_str(cls, value, field: ModelField):
        """Return the raw value untouched.

        Comma splitting is handled by the field types built in
        `_list_to_str_fields`, so no string splitting happens here.
        """
        return value

    @root_validator()
    def validate_filter_and_sort_combinations(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Validate that the filter and sort combinations are valid for Firestore.

        Changes:
            - If there is an inequality filter, the first sort order must be the same.
            - If there is an inequality filter and no sort order, it will be added.

        Args:
            values: The validated field values of the filter.

        Returns:
            The (possibly updated) values.

        Raises:
            ValueError: If inequality filters target more than one field, or
                the first sort order is not the inequality-filtered field.
        """
        orders = values.get(cls.Constants.ordering_field_name, None)

        # Collect the (suffix-less) names of fields with an active inequality filter.
        unequal_filter_fields = set()
        for raw_field_name, value in values.items():
            if value is None:
                # Unset filters are not applied, so they cannot conflict.
                continue

            field_name, _, raw_operator = raw_field_name.partition("__")
            if raw_operator in _orm_op_conflicts_with_sorting:
                unequal_filter_fields.add(field_name)

        if not unequal_filter_fields:
            return values

        if len(unequal_filter_fields) > 1:
            raise ValueError(
                f"Cannot have inequality on multiple fields: {unequal_filter_fields}"
            )

        if not orders:
            # Pagination does not work without this ordering
            values[cls.Constants.ordering_field_name] = list(unequal_filter_fields)
            return values

        # Compare on the bare field name, ignoring the direction prefix.
        first_order = orders[0].lstrip("+-")
        filter_field = unequal_filter_fields.pop()
        if filter_field != first_order:
            raise ValueError(
                f"Inequality filter property and first sort order must be the same: {filter_field} and {first_order}"
            )

        return values

    def filter(self, query: FilterQuery | QuerySet | Manager) -> FilterQuery:
        """Apply every filtering field of this filter to ``query``.

        Nested `FireoFilter` values are applied recursively; every other field
        is dispatched through ``_orm_operator_filter`` using the operator
        suffix of its name.

        Args:
            query: The query (or manager) to narrow down.

        Returns:
            The query with all filters applied.

        Raises:
            KeyError: If a field name carries an unknown operator suffix.
        """
        for raw_field_name, value in self.filtering_fields:
            field_value = getattr(self, raw_field_name)
            if isinstance(field_value, FireoFilter):
                query = field_value.filter(query)
                continue

            field_name, _, raw_operator = raw_field_name.partition("__")
            query = _orm_operator_filter[raw_operator](query, field_name, value)

        return query

    def sort(self, query: FilterQuery | QuerySet | Manager) -> FilterQuery:
        """Apply the requested ordering to ``query``.

        Args:
            query: The query (or manager) to order.

        Returns:
            ``query`` unchanged when no ordering is set, otherwise the query
            ordered by every entry of ``ordering_values`` in turn.
        """
        if not self.ordering_values:
            return query

        for order in self.ordering_values:
            # The validator above accepts an explicit ascending "+" prefix,
            # but FireO's `order` only recognises a leading "-" (descending);
            # drop the "+" so it is not treated as part of the field name.
            query = query.order(order.removeprefix("+"))

        return query


ListItem = TypeVar('ListItem')


class CommaSepList(list, Generic[ListItem]):
    """List type whose validator expands a single comma-separated item."""

    @classmethod
    def __get_validators__(cls):
        """Yield the validators to run for this type (only `validate`)."""
        yield cls.validate

    @classmethod
    def validate(cls, v):
        """Split ``v`` on commas when it is a list holding exactly one item.

        Any other value (non-list, empty list, several items) is returned
        unchanged.
        """
        holds_single_item = isinstance(v, list) and len(v) == 1
        if not holds_single_item:
            return v

        (only_item,) = v
        return only_item.split(',')


class Order(str, Generic[ListItem]):
    """Build a ``Literal`` of allowed ordering values via ``Order[...]``.

    ``Order["a", "b"]`` evaluates to ``Literal["a", "b", "-a", "-b"]``: each
    given name, followed by each name prefixed with ``-``.
    """

    def __class_getitem__(cls, items) -> Type[Literal[ListItem]]:  # type: ignore
        # A single subscript arrives as a bare value; normalise to a tuple.
        names = items if isinstance(items, tuple) else (items,)

        # Every option must be a plain `str`.
        assert {type(name) for name in names} == {str}

        prefixes = ['', '-']
        options = tuple(f'{prefix}{name}' for prefix in prefixes for name in names)
        return Literal[options]  # type: ignore


def _list_to_comma_list(type_):
    """Turn a parametrised list type into the matching ``CommaSepList``.

    Types whose ``__origin__`` is not ``list`` (including types without an
    ``__origin__`` attribute) are returned unchanged.
    """
    origin = getattr(type_, "__origin__", None)
    if origin is not list:
        return type_

    item_type = type_.__args__[0]
    return CommaSepList[item_type]


def _list_to_str_fields(Filter: Type[BaseFilterModel]):
    """Prepare filter fields to be used in query params.

    Unlike the original implementation, this one:
        - allows to use lists in query params as multiple values for the same field
        - split comma separated values in query params to lists, so "split_str"
            is no longer needed

    Both union spellings are supported for list fields: ``list[str] | None``
    (``types.UnionType``) and ``Optional[list[str]]`` / ``Union[...]``
    (``typing.Union``).

    Args:
        Filter: The filter model class whose fields are converted.

    Returns:
        A mapping of field name to a ``(type, field_info)`` tuple, where the
        field info always carries a ``Query`` default.
    """
    ret: Dict[str, Tuple[Union[object, Type], Optional[FieldInfo]]] = {}
    for f in Filter.__fields__.values():
        # Work on a copy so the model's own field info is not mutated.
        field_info = deepcopy(f.field_info)
        if not isinstance(field_info.default, params.Query):
            if field_info.default is not Undefined:
                default = field_info.default
            elif f.required:
                default = ...
            else:
                default = None

            field_info.default = Query(default)

        # Prefer the annotation as written on the class; fall back to the
        # type recorded on the field when the class has no own annotation.
        field_type = Filter.__annotations__.get(f.name, f.outer_type_)
        if f.shape == SHAPE_LIST:
            # `X | Y` produces a types.UnionType instance, while
            # Optional[X] / Union[X, Y] produce an alias whose __origin__ is
            # typing.Union. Treat both the same way so that every list member
            # of the union becomes a CommaSepList.
            is_union = (
                isinstance(field_type, UnionType)
                or getattr(field_type, "__origin__", None) is Union
            )
            if is_union:
                items = [_list_to_comma_list(arg) for arg in field_type.__args__]
                new_field_type = Union[tuple(items)]  # type: ignore
            else:
                new_field_type = _list_to_comma_list(field_type)
            ret[f.name] = (new_field_type, field_info)
        else:
            ret[f.name] = (field_type if f.required else Optional[field_type], field_info)

    return ret


# Monkey-patch: replace `_list_to_str_fields` in `fastapi_filter.base.filter`
# with the implementation defined in this module. This takes effect as a side
# effect of importing this module.
# NOTE(review): presumably fastapi_filter looks this helper up at call time
# when building query-parameter dependencies — confirm against the library.
filter_lib._list_to_str_fields = _list_to_str_fields
@ADR-007 ADR-007 added the enhancement New feature or request label Mar 6, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

No branches or pull requests

1 participant