Skip to content

Commit 583fd97

Browse files
committed
more progress
1 parent 8cd39f5 commit 583fd97

File tree

12 files changed

+413
-286
lines changed

12 files changed

+413
-286
lines changed

django_mongodb/fields/array.py

Lines changed: 38 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,24 @@
11
import json
22

3-
from django.contrib.postgres.forms import SimpleArrayField
43
from django.contrib.postgres.validators import ArrayMaxLengthValidator
54
from django.core import checks, exceptions
6-
from django.db.models import DecimalField, Field, Func, Transform, Value
5+
from django.db.models import DecimalField, Field, Func, IntegerField, Transform, Value
76
from django.db.models.fields.mixins import CheckFieldDefaultMixin
7+
from django.db.models.lookups import In
88
from django.utils.translation import gettext_lazy as _
99

10-
__all__ = ["ArrayField"]
10+
from django_mongodb.forms import SimpleArrayField
1111

12+
from ..utils import prefix_validation_error
1213

13-
from django.core.exceptions import ValidationError
14-
from django.utils.functional import SimpleLazyObject
15-
from django.utils.text import format_lazy
14+
__all__ = ["ArrayField"]
1615

1716

1817
class AttributeSetter:
1918
def __init__(self, name, value):
2019
setattr(self, name, value)
2120

2221

23-
def prefix_validation_error(error, prefix, code, params):
24-
"""
25-
Prefix a validation error message while maintaining the existing
26-
validation data structure.
27-
"""
28-
if error.error_list == [error]:
29-
error_params = error.params or {}
30-
return ValidationError(
31-
# We can't simply concatenate messages since they might require
32-
# their associated parameters to be expressed correctly which
33-
# is not something `format_lazy` does. For example, proxied
34-
# ngettext calls require a count parameter and are converted
35-
# to an empty string if they are missing it.
36-
message=format_lazy(
37-
"{} {}",
38-
SimpleLazyObject(lambda: prefix % params),
39-
SimpleLazyObject(lambda: error.message % error_params),
40-
),
41-
code=code,
42-
params={**error_params, **params},
43-
)
44-
return ValidationError(
45-
[prefix_validation_error(e, prefix, code, params) for e in error.error_list]
46-
)
47-
48-
4922
class ArrayField(CheckFieldDefaultMixin, Field):
5023
empty_strings_allowed = False
5124
default_error_messages = {
@@ -293,55 +266,44 @@ def _rhs_not_none_values(self, rhs):
293266
yield True
294267

295268

296-
# @ArrayField.register_lookup
297-
# class ArrayContains(ArrayRHSMixin, lookups.DataContains):
298-
# pass
299-
300-
301-
# @ArrayField.register_lookup
302-
# class ArrayContainedBy(ArrayRHSMixin, lookups.ContainedBy):
303-
# pass
304-
305-
306269
# @ArrayField.register_lookup
307270
# class ArrayExact(ArrayRHSMixin, Exact):
308-
# pass
271+
# pass
309272

310273

311-
# @ArrayField.register_lookup
312-
# class ArrayOverlap(ArrayRHSMixin, lookups.Overlap):
313-
# pass
274+
@ArrayField.register_lookup
275+
class ArrayLenTransform(Transform):
276+
lookup_name = "len"
277+
output_field = IntegerField()
314278

315-
316-
# @ArrayField.register_lookup
317-
# class ArrayLenTransform(Transform):
318-
# lookup_name = "len"
319-
# output_field = IntegerField()
320-
321-
# def as_sql(self, compiler, connection):
322-
# lhs, params = compiler.compile(self.lhs)
323-
# # Distinguish NULL and empty arrays
324-
# return (
325-
# "CASE WHEN %(lhs)s IS NULL THEN NULL ELSE "
326-
# "coalesce(array_length(%(lhs)s, 1), 0) END"
327-
# ) % {"lhs": lhs}, params * 2
279+
def as_sql(self, compiler, connection):
280+
lhs, params = compiler.compile(self.lhs)
281+
# Distinguish NULL and empty arrays
282+
return (
283+
(
284+
"" # "CASE WHEN %(lhs)s IS NULL THEN NULL ELSE "
285+
# "coalesce(array_length(%(lhs)s, 1), 0) END"
286+
)
287+
% {},
288+
params * 2,
289+
)
328290

329291

330-
# @ArrayField.register_lookup
331-
# class ArrayInLookup(In):
332-
# def get_prep_lookup(self):
333-
# values = super().get_prep_lookup()
334-
# if hasattr(values, "resolve_expression"):
335-
# return values
336-
# # In.process_rhs() expects values to be hashable, so convert lists
337-
# # to tuples.
338-
# prepared_values = []
339-
# for value in values:
340-
# if hasattr(value, "resolve_expression"):
341-
# prepared_values.append(value)
342-
# else:
343-
# prepared_values.append(tuple(value))
344-
# return prepared_values
292+
@ArrayField.register_lookup
293+
class ArrayInLookup(In):
294+
def get_prep_lookup(self):
295+
values = super().get_prep_lookup()
296+
if hasattr(values, "resolve_expression"):
297+
return values
298+
# In.process_rhs() expects values to be hashable, so convert lists
299+
# to tuples.
300+
prepared_values = []
301+
for value in values:
302+
if hasattr(value, "resolve_expression"):
303+
prepared_values.append(value)
304+
else:
305+
prepared_values.append(tuple(value))
306+
return prepared_values
345307

346308

347309
class IndexTransform(Transform):
@@ -388,6 +350,5 @@ def __init__(self, start, end):
388350
self.start = start
389351
self.end = end
390352

391-
392-
# def __call__(self, *args, **kwargs):
393-
# return SliceTransform(self.start, self.end, *args, **kwargs)
353+
def __call__(self, *args, **kwargs):
354+
return SliceTransform(self.start, self.end, *args, **kwargs)

django_mongodb/forms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .array import * # NOQA: F403

0 commit comments

Comments
 (0)