4
4
from django .core import checks , exceptions
5
5
from django .db .models import DecimalField , Field , Func , IntegerField , Transform , Value
6
6
from django .db .models .fields .mixins import CheckFieldDefaultMixin
7
- from django .db .models .lookups import FieldGetDbPrepValueMixin , In , Lookup
7
+ from django .db .models .lookups import Exact , FieldGetDbPrepValueMixin , In , Lookup
8
8
from django .utils .translation import gettext_lazy as _
9
9
10
10
from django_mongodb .forms import SimpleArrayField
@@ -235,6 +235,11 @@ def formfield(self, **kwargs):
235
235
)
236
236
237
237
238
+ class Array (Func ):
239
+ def as_mql (self , compiler , connection ):
240
+ return [expr .as_mql (compiler , connection ) for expr in self .get_source_expressions ()]
241
+
242
+
238
243
class ArrayRHSMixin :
239
244
def __init__ (self , lhs , rhs ):
240
245
# Don't wrap arrays that contains only None values, psycopg doesn't
@@ -246,18 +251,9 @@ def __init__(self, lhs, rhs):
246
251
field = lhs .output_field
247
252
value = Value (field .base_field .get_prep_value (value ))
248
253
expressions .append (value )
249
- rhs = Func (
250
- * expressions ,
251
- function = "ARRAY" ,
252
- template = "%(function)s[%(expressions)s]" ,
253
- )
254
+ rhs = Array (* expressions )
254
255
super ().__init__ (lhs , rhs )
255
256
256
- def process_rhs (self , compiler , connection ):
257
- rhs , rhs_params = super ().process_rhs (compiler , connection )
258
- cast_type = self .lhs .output_field .cast_db_type (connection )
259
- return f"{ rhs } ::{ cast_type } " , rhs_params
260
-
261
257
def _rhs_not_none_values (self , rhs ):
262
258
for x in rhs :
263
259
if isinstance (x , list | tuple ):
@@ -267,29 +263,29 @@ def _rhs_not_none_values(self, rhs):
267
263
268
264
269
265
@ArrayField .register_lookup
270
- class ArrayContains (FieldGetDbPrepValueMixin , Lookup ):
266
+ class ArrayContains (ArrayRHSMixin , FieldGetDbPrepValueMixin , Lookup ):
271
267
lookup_name = "contains"
272
268
273
269
def as_mql (self , compiler , connection ):
274
270
lhs_mql = process_lhs (self , compiler , connection )
275
271
value = process_rhs (self , compiler , connection )
276
272
return {
277
- "$gt " : [
273
+ "$eq " : [
278
274
{
279
275
"$cond" : {
280
276
"if" : {"$eq" : [lhs_mql , None ]},
281
- "then" : None ,
282
- "else" : {"$size " : { "$setIntersection" : [ lhs_mql , value ]} },
277
+ "then" : False ,
278
+ "else" : {"$setIsSubset " : [ value , lhs_mql ] },
283
279
}
284
280
},
285
- 0 ,
281
+ True ,
286
282
]
287
283
}
288
284
289
285
290
- # @ArrayField.register_lookup
291
- # class ArrayExact(ArrayRHSMixin, Exact):
292
- # pass
286
+ @ArrayField .register_lookup
287
+ class ArrayExact (ArrayRHSMixin , Exact ):
288
+ pass
293
289
294
290
295
291
@ArrayField .register_lookup
0 commit comments