From b68b13f279992ad6eb8c536c9804502e95f5a0a1 Mon Sep 17 00:00:00 2001 From: Tim Graham Date: Fri, 12 Jul 2024 18:09:15 -0400 Subject: [PATCH] add xor support for Q objects --- .github/workflows/test-python.yml | 1 + django_mongodb/query.py | 26 +++++++++++++++++++++----- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index e8c53d81..7548992b 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -99,6 +99,7 @@ jobs: sessions_tests timezones update + xor_lookups docs: name: Docs Checks diff --git a/django_mongodb/query.py b/django_mongodb/query.py index 0f606625..fc684f89 100644 --- a/django_mongodb/query.py +++ b/django_mongodb/query.py @@ -1,11 +1,14 @@ -from functools import wraps +from functools import reduce, wraps +from operator import add as add_operator from django.core.exceptions import EmptyResultSet, FullResultSet from django.db import DatabaseError, IntegrityError -from django.db.models import Value +from django.db.models.expressions import Case, Value, When +from django.db.models.functions import Mod +from django.db.models.lookups import Exact from django.db.models.sql.constants import INNER from django.db.models.sql.datastructures import Join -from django.db.models.sql.where import AND, XOR, WhereNode +from django.db.models.sql.where import AND, OR, XOR, WhereNode from pymongo import ASCENDING, DESCENDING from pymongo.errors import DuplicateKeyError, PyMongoError @@ -219,8 +222,21 @@ def where_node(self, compiler, connection): if self.connector == AND: operator = "$and" elif self.connector == XOR: - # https://github.com/mongodb-labs/django-mongodb/issues/27 - raise NotImplementedError("XOR is not yet supported.") + # MongoDB doesn't support $xor, so convert: + # a XOR b XOR c XOR ... + # to: + # (a OR b OR c OR ...) AND MOD(a + b + c + ..., 2) == 1 + # The result of an n-ary XOR is true when an odd number of operands + # are true. + lhs = self.__class__(self.children, OR) + rhs_sum = reduce( + add_operator, + (Case(When(c, then=1), default=0) for c in self.children), + ) + if len(self.children) > 2: + rhs_sum = Mod(rhs_sum, 2) + rhs = Exact(1, rhs_sum) + return self.__class__([lhs, rhs], AND, self.negated).as_mql(compiler, connection) else: operator = "$or"