From 97177d2de527e881bc642bce37cbac3044ddb7d0 Mon Sep 17 00:00:00 2001 From: Yurii Karabas <1998uriyyo@gmail.com> Date: Mon, 4 Sep 2023 18:45:45 +0300 Subject: [PATCH] Add hook func for sqlalchemy apply filter --- fastapi_filters/ext/sqlalchemy.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/fastapi_filters/ext/sqlalchemy.py b/fastapi_filters/ext/sqlalchemy.py index e567eae..040099b 100644 --- a/fastapi_filters/ext/sqlalchemy.py +++ b/fastapi_filters/ext/sqlalchemy.py @@ -90,9 +90,9 @@ def _default_apply_filter(*_: Any) -> Any: raise NotImplementedError -custom_apply_filter: ConfigVar[ - Callable[[TSelectable, EntityNamespace, str, AbstractFilterOperator, Any], TSelectable] -] = ConfigVar( +ApplyFilterFunc: TypeAlias = Callable[[TSelectable, EntityNamespace, str, AbstractFilterOperator, Any], TSelectable] + +custom_apply_filter: ConfigVar[ApplyFilterFunc[Any]] = ConfigVar( "apply_filter", default=_default_apply_filter, ) @@ -104,11 +104,20 @@ def _apply_filter( field: str, op: AbstractFilterOperator, val: Any, + apply_filter: Optional[ApplyFilterFunc[TSelectable]] = None, ) -> TSelectable: custom_apply_filter_impl = custom_apply_filter.get() try: - cond = custom_apply_filter_impl(stmt, ns, field, op, val) + cond = None + if apply_filter: + try: + cond = apply_filter(stmt, ns, field, op, val) + except NotImplementedError: + pass + + if cond is None: + cond = custom_apply_filter_impl(stmt, ns, field, op, val) except NotImplementedError: try: cond = DEFAULT_FILTERS[op](ns[field], val) @@ -124,6 +133,7 @@ def apply_filters( *, remapping: Optional[Mapping[str, str]] = None, additional: Optional[EntityNamespace] = None, + apply_filter: Optional[ApplyFilterFunc[TSelectable]] = None, ) -> TSelectable: if isinstance(filters, FilterSet): filters = filters.filter_values @@ -135,7 +145,7 @@ def apply_filters( field = remapping.get(field, field) for op, val in field_filters.items(): - stmt = _apply_filter(stmt, ns, field, op, val) + stmt = _apply_filter(stmt, ns, field, op, val, apply_filter) return stmt @@ -173,8 +183,9 @@ def apply_filters_and_sorting( *, remapping: Optional[Mapping[str, str]] = None, additional: Optional[EntityNamespace] = None, + apply_filter: Optional[ApplyFilterFunc[TSelectable]] = None, ) -> TSelectable: - stmt = apply_filters(stmt, filters, remapping=remapping, additional=additional) + stmt = apply_filters(stmt, filters, remapping=remapping, additional=additional, apply_filter=apply_filter) stmt = apply_sorting(stmt, sorting, remapping=remapping, additional=additional) return stmt