diff --git a/pyodmongo/queries/query_string.py b/pyodmongo/queries/query_string.py index fd74b4f..289db13 100644 --- a/pyodmongo/queries/query_string.py +++ b/pyodmongo/queries/query_string.py @@ -1,8 +1,8 @@ from ..models.db_model import DbModel from ..models.db_field_info import DbField -from ..models.query_operators import QueryOperator -from .operators import and_, sort -from typing import Type +from ..models.query_operators import QueryOperator, LogicalOperator +from .operators import and_, or_, sort +from typing import Type, Literal from datetime import datetime import re @@ -69,7 +69,8 @@ def js_regex_to_python(js_regex_str): def mount_query_filter( Model: Type[DbModel], items: dict, - initial_comparison_operators: list[QueryOperator], + query_operator: Literal["and", "or"] = "and", + initial_comparison_operators: list[QueryOperator] = [], ) -> QueryOperator: """ Constructs a MongoDB query filter from a dictionary of conditions and initializes @@ -100,36 +101,68 @@ def mount_query_filter( raise TypeError("Model must be a DbModel") sort_operators = None for key, value in items.items(): + key = key.strip() value = value.strip() if value == "": continue split_result = key.strip().rsplit(sep="_", maxsplit=1) - operator = f"${split_result[-1]}" - if operator not in ["$eq", "$gt", "$gte", "$in", "$lt", "$lte", "$ne", "$nin"]: - if operator in ["$sort"]: + operator = f"{split_result[-1]}" + if operator not in [ + "eq", + "gt", + "gte", + "in", + "lt", + "lte", + "ne", + "nin", + "and", + "or", + ]: + if operator in ["sort"]: value = eval(value) for v in value: v[0] = getattr(Model, v[0]) sort_operators = sort(*value) continue + if operator in ["and", "or"]: + value, _ = mount_query_filter( + Model=Model, + items=eval(value), + query_operator=operator, + initial_comparison_operators=[], + ) try: value = datetime.fromisoformat(value) except (TypeError, ValueError): try: + if type(value) == str and ( + value.capitalize() == "True" or value.capitalize() == "False" + ): + value = value.capitalize() value = eval(value) - except (NameError, SyntaxError): + except (NameError, SyntaxError, TypeError): value = value field_name = split_result[0] if type(value) is list: for index, item in enumerate(value): value[index] = js_regex_to_python(item) - try: - db_field_info: DbField = eval("Model." + field_name) - except AttributeError: - raise AttributeError(f"There's no field '{field_name}' in {Model.__name__}") - initial_comparison_operators.append( - db_field_info.comparison_operator(operator=operator, value=value) - ) + if type(value) != LogicalOperator: + try: + db_field_info: DbField = getattr(Model, field_name) + except AttributeError: + raise AttributeError( + f"There's no field '{field_name}' in {Model.__name__}" + ) + initial_comparison_operators.append( + db_field_info.comparison_operator(operator="$" + operator, value=value) + ) + else: + initial_comparison_operators.append(value) + pass if len(initial_comparison_operators) == 0: return None, sort_operators - return and_(*initial_comparison_operators), sort_operators + if query_operator == "or": + return or_(*initial_comparison_operators), sort_operators + else: + return and_(*initial_comparison_operators), sort_operators diff --git a/tests/test_queries.py b/tests/test_queries.py index b6ffe72..d8468d7 100644 --- a/tests/test_queries.py +++ b/tests/test_queries.py @@ -481,3 +481,24 @@ class MyModel(DbModel): def test_to_dict_query_operator_default(): assert QueryOperator().to_dict() is None + + +def test_mount_query_filter_with_logical_operators(): + class MyClass(DbModel): + attr_1: str + attr_2: int + attr_3: bool + _collection: ClassVar = "my_class" + + dict_input = { + "attr_3_eq": "true", + "_or": "{'attr_1_eq': 'value_1', 'attr_2_lte': '10'}", + } + query, _ = mount_query_filter(Model=MyClass, items=dict_input) + assert query == and_( + eq(MyClass.attr_3, True), + or_(eq(MyClass.attr_1, "value_1"), lte(MyClass.attr_2, 10)), + ) + assert query == (MyClass.attr_3 == True) & ( + (MyClass.attr_1 == "value_1") | (MyClass.attr_2 <= 10) + )