diff --git a/grapple/types/structures.py b/grapple/types/structures.py index 3262c4b2..45b6ff8e 100644 --- a/grapple/types/structures.py +++ b/grapple/types/structures.py @@ -21,6 +21,19 @@ def parse_literal(ast, _variables=None): return return_value +class SearchOperatorEnum(graphene.Enum): + """ + Enum for search operator. + """ + AND = "and" + OR = "or" + + def __str__(self): + # the core search parser expects the operator to be a string. + # the default __str__ returns SearchOperatorEnum.AND/OR, + # this __str__ returns the value and/or for compatibility. + return self.value + class QuerySetList(graphene.List): """ List type with arguments used by Django's query sets. @@ -31,6 +44,8 @@ class QuerySetList(graphene.List): * ``limit`` * ``offset`` * ``search_query`` + * ``search_operator`` + * ``search_fields`` * ``order`` :param enable_limit: Enable limit argument. @@ -39,6 +54,10 @@ class QuerySetList(graphene.List): :type enable_offset: bool :param enable_search: Enable search query argument. :type enable_search: bool + :param enable_search_fields: Enable search fields argument, enable_search must also be True + :type enable_search_fields: bool + :param enable_search_operator: Enable search operator argument, enable_search must also be True + :type enable_search_operator: bool :param enable_order: Enable ordering via query argument. :type enable_order: bool """ @@ -47,6 +66,8 @@ def __init__(self, of_type, *args, **kwargs): enable_limit = kwargs.pop("enable_limit", True) enable_offset = kwargs.pop("enable_offset", True) enable_search = kwargs.pop("enable_search", True) + enable_search_fields = kwargs.pop("enable_search_fields", True) + enable_search_operator = kwargs.pop("enable_search_operator", True) enable_order = kwargs.pop("enable_order", True) # Check if the type is a Django model type. Do not perform the @@ -92,6 +113,22 @@ def __init__(self, of_type, *args, **kwargs): graphene.String, description=_("Filter the results using Wagtail's search."), ) + if enable_search_operator: + kwargs["search_operator"] = graphene.Argument( + SearchOperatorEnum, + description=_( + "Specify search operator (and/or), see: https://docs.wagtail.org/en/stable/topics/search/searching.html#search-operator" + ), + default_value="and", + ) + + if enable_search_fields: + kwargs["search_field"] = graphene.Argument( + graphene.List(graphene.String), + description=_( + "A list of fields to search in. see: https://docs.wagtail.org/en/stable/topics/search/searching.html#specifying-the-fields-to-search" + ), + ) if "id" not in kwargs: kwargs["id"] = graphene.Argument(graphene.ID, description=_("Filter by ID")) @@ -138,21 +175,29 @@ def PaginatedQuerySet(of_type, type_class, **kwargs): """ Paginated QuerySet type with arguments used by Django's query sets. - This type setts the following arguments on itself: + This type sets the following arguments on itself: * ``id`` * ``page`` * ``per_page`` * ``search_query`` + * ``search_operator`` + * ``search_fields`` * ``order`` :param enable_search: Enable search query argument. :type enable_search: bool + :param enable_search_fields: Enable search fields argument, enable_search must also be True + :type enable_search_fields: bool + :param enable_search_operator: Enable search operator argument, enable_search must also be True + :type enable_search_operator: bool :param enable_order: Enable ordering via query argument. :type enable_order: bool """ enable_search = kwargs.pop("enable_search", True) + enable_search_fields = kwargs.pop("enable_search_fields", True) + enable_search_operator = kwargs.pop("enable_search_operator", True) enable_order = kwargs.pop("enable_order", True) required = kwargs.get("required", False) type_name = type_class if isinstance(type_class, str) else type_class.__name__ @@ -199,6 +244,22 @@ def PaginatedQuerySet(of_type, type_class, **kwargs): kwargs["search_query"] = graphene.Argument( graphene.String, description=_("Filter the results using Wagtail's search.") ) + if enable_search_operator: + kwargs["search_operator"] = graphene.Argument( + SearchOperatorEnum, + description=_( + "Specify search operator (and/or), see: https://docs.wagtail.org/en/stable/topics/search/searching.html#search-operator" + ), + default_value="and", + ) + + if enable_search_fields: + kwargs["search_field"] = graphene.Argument( + graphene.List(graphene.String), + description=_( + "A list of fields to search in. see: https://docs.wagtail.org/en/stable/topics/search/searching.html#specifying-the-fields-to-search" + ), + ) if "id" not in kwargs: kwargs["id"] = graphene.Argument(graphene.ID, description=_("Filter by ID")) diff --git a/grapple/utils.py b/grapple/utils.py index ede62155..c0a97af0 100644 --- a/grapple/utils.py +++ b/grapple/utils.py @@ -8,6 +8,7 @@ from wagtail import VERSION as WAGTAIL_VERSION from wagtail.models import Site from wagtail.search.index import class_is_indexed +from wagtail.search.utils import parse_query_string from .settings import grapple_settings from .types.structures import BasePaginatedType, PaginationType @@ -100,6 +101,8 @@ def resolve_queryset( id=None, order=None, collection=None, + search_operator="and", + search_fields=None, **kwargs, ): """ @@ -121,6 +124,11 @@ def resolve_queryset( :type order: str :param collection: Use Wagtail's collection id to filter images or documents :type collection: int + :param search_operator: The operator to use when combining search terms. + Defaults to "and". + :type search_operator: "and" | "or" + :param search_fields: A list of fields to search. Defaults to all fields. + :type search_fields: list """ qs = qs.all() if id is None else qs.filter(pk=id) @@ -147,7 +155,14 @@ def resolve_queryset( query = Query.get(search_query) query.add_hit() - qs = qs.search(search_query, order_by_relevance=order_by_relevance) + filters, parsed_query = parse_query_string(search_query, str(search_operator)) + + qs = qs.search( + parsed_query, + order_by_relevance=order_by_relevance, + operator=search_operator, + fields=search_fields, + ) if connection.vendor != "sqlite": qs = qs.annotate_score("search_score") @@ -188,7 +203,16 @@ def get_paginated_result(qs, page, per_page): def resolve_paginated_queryset( - qs, info, page=None, per_page=None, search_query=None, id=None, order=None, **kwargs + qs, + info, + page=None, + per_page=None, + id=None, + order=None, + search_query=None, + search_operator="and", + search_fields=None, + **kwargs, ): """ Add page, per_page and search capabilities to the query. This contains @@ -202,11 +226,16 @@ def resolve_paginated_queryset( :type id: int :param per_page: The maximum number of items to include on a page. :type per_page: int + :param order: Order the query set using the Django QuerySet order_by format. + :type order: str :param search_query: Using Wagtail search, exclude objects that do not match the search query. :type search_query: str - :param order: Order the query set using the Django QuerySet order_by format. - :type order: str + :param search_operator: The operator to use when combining search terms. + Defaults to "and". + :type search_operator: "and" | "or" + :param search_fields: A list of fields to search. Defaults to all fields. + :type search_fields: list """ page = int(page or 1) per_page = min( @@ -231,7 +260,14 @@ def resolve_paginated_queryset( query = Query.get(search_query) query.add_hit() - qs = qs.search(search_query, order_by_relevance=order_by_relevance) + filters, parsed_query = parse_query_string(search_query, search_operator) + + qs = qs.search( + parsed_query, + order_by_relevance=order_by_relevance, + operator=search_operator, + fields=search_fields, + ) if connection.vendor != "sqlite": qs = qs.annotate_score("search_score") diff --git a/tests/test_grapple.py b/tests/test_grapple.py index 1c2445f9..8d2b0d6f 100644 --- a/tests/test_grapple.py +++ b/tests/test_grapple.py @@ -558,8 +558,8 @@ def test_explicit_order(self): executed = self.client.execute( query, variables={"searchQuery": "Gamma", "order": "-title"} ) - page_data = executed["data"].get("pages") + page_data = executed["data"].get("pages") self.assertEqual(len(page_data), 6) self.assertEqual(page_data[0]["title"], "Gamma Gamma") self.assertEqual(page_data[1]["title"], "Gamma Beta") @@ -569,6 +569,69 @@ def test_explicit_order(self): self.assertEqual(page_data[5]["title"], "Alpha Gamma") + def test_search_operator_default(self): + """ default operator is and""" + query = """ + query($searchQuery: String) { + pages(searchQuery: $searchQuery) { + title + searchScore + } + } + """ + executed = self.client.execute( + query, variables={"searchQuery": "Alpha Beta"} + ) + page_data = executed["data"].get("pages") + self.assertEqual(len(page_data), 2) + self.assertEqual(page_data[0]["title"], "Alpha Beta") + self.assertEqual(page_data[1]["title"], "Beta Alpha") + + + def test_search_operator_and(self): + query = """ + query($searchQuery: String, $searchOperator: SearchOperatorEnum) { + pages(searchQuery: $searchQuery, searchOperator: $searchOperator) { + title + searchScore + } + } + """ + executed = self.client.execute( + query, variables={"searchQuery": "Alpha Beta", "searchOperator": "AND"} + ) + page_data = executed["data"].get("pages") + self.assertEqual(len(page_data), 2) + self.assertEqual(page_data[0]["title"], "Alpha Beta") + self.assertEqual(page_data[1]["title"], "Beta Alpha") + + + def test_search_operator_or(self): + query = """ + query($searchQuery: String, $searchOperator: SearchOperatorEnum) { + pages(searchQuery: $searchQuery, searchOperator: $searchOperator) { + title + searchScore + } + } + """ + executed = self.client.execute( + query, variables={"searchQuery": "Alpha Beta", "searchOperator": "OR"} + ) + page_data = executed["data"].get("pages") + self.assertEqual(len(page_data), 10) + self.assertEqual(page_data[0]["title"], "Alpha") + self.assertEqual(page_data[1]["title"], "Alpha Alpha") + self.assertEqual(page_data[2]["title"], "Alpha Beta") + self.assertEqual(page_data[3]["title"], "Alpha Gamma") + self.assertEqual(page_data[4]["title"], "Beta") + self.assertEqual(page_data[5]["title"], "Beta Alpha") + self.assertEqual(page_data[6]["title"], "Beta Beta") + self.assertEqual(page_data[7]["title"], "Beta Gamma") + self.assertEqual(page_data[8]["title"], "Gamma Alpha") + self.assertEqual(page_data[9]["title"], "Gamma Beta") + + class PageUrlPathTest(BaseGrappleTest): def _query_by_path(self, path, *, in_site=False): query = """