Skip to content

Commit

Permalink
more progress
Browse files Browse the repository at this point in the history
  • Loading branch information
timgraham committed Dec 9, 2024
1 parent 8cd39f5 commit 583fd97
Show file tree
Hide file tree
Showing 12 changed files with 413 additions and 286 deletions.
115 changes: 38 additions & 77 deletions django_mongodb/fields/array.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,24 @@
import json

from django.contrib.postgres.forms import SimpleArrayField
from django.contrib.postgres.validators import ArrayMaxLengthValidator
from django.core import checks, exceptions
from django.db.models import DecimalField, Field, Func, Transform, Value
from django.db.models import DecimalField, Field, Func, IntegerField, Transform, Value
from django.db.models.fields.mixins import CheckFieldDefaultMixin
from django.db.models.lookups import In
from django.utils.translation import gettext_lazy as _

__all__ = ["ArrayField"]
from django_mongodb.forms import SimpleArrayField

from ..utils import prefix_validation_error

from django.core.exceptions import ValidationError
from django.utils.functional import SimpleLazyObject
from django.utils.text import format_lazy
__all__ = ["ArrayField"]


class AttributeSetter:
def __init__(self, name, value):
setattr(self, name, value)


def prefix_validation_error(error, prefix, code, params):
"""
Prefix a validation error message while maintaining the existing
validation data structure.
"""
if error.error_list == [error]:
error_params = error.params or {}
return ValidationError(
# We can't simply concatenate messages since they might require
# their associated parameters to be expressed correctly which
# is not something `format_lazy` does. For example, proxied
# ngettext calls require a count parameter and are converted
# to an empty string if they are missing it.
message=format_lazy(
"{} {}",
SimpleLazyObject(lambda: prefix % params),
SimpleLazyObject(lambda: error.message % error_params),
),
code=code,
params={**error_params, **params},
)
return ValidationError(
[prefix_validation_error(e, prefix, code, params) for e in error.error_list]
)


class ArrayField(CheckFieldDefaultMixin, Field):
empty_strings_allowed = False
default_error_messages = {
Expand Down Expand Up @@ -293,55 +266,44 @@ def _rhs_not_none_values(self, rhs):
yield True


# @ArrayField.register_lookup
# class ArrayContains(ArrayRHSMixin, lookups.DataContains):
# pass


# @ArrayField.register_lookup
# class ArrayContainedBy(ArrayRHSMixin, lookups.ContainedBy):
# pass


# @ArrayField.register_lookup
# class ArrayExact(ArrayRHSMixin, Exact):
# pass
# pass


# @ArrayField.register_lookup
# class ArrayOverlap(ArrayRHSMixin, lookups.Overlap):
# pass
@ArrayField.register_lookup
class ArrayLenTransform(Transform):
lookup_name = "len"
output_field = IntegerField()


# @ArrayField.register_lookup
# class ArrayLenTransform(Transform):
# lookup_name = "len"
# output_field = IntegerField()

# def as_sql(self, compiler, connection):
# lhs, params = compiler.compile(self.lhs)
# # Distinguish NULL and empty arrays
# return (
# "CASE WHEN %(lhs)s IS NULL THEN NULL ELSE "
# "coalesce(array_length(%(lhs)s, 1), 0) END"
# ) % {"lhs": lhs}, params * 2
def as_sql(self, compiler, connection):
lhs, params = compiler.compile(self.lhs)
# Distinguish NULL and empty arrays
return (
(
"" # "CASE WHEN %(lhs)s IS NULL THEN NULL ELSE "
# "coalesce(array_length(%(lhs)s, 1), 0) END"
)
% {},
params * 2,
)


# @ArrayField.register_lookup
# class ArrayInLookup(In):
# def get_prep_lookup(self):
# values = super().get_prep_lookup()
# if hasattr(values, "resolve_expression"):
# return values
# # In.process_rhs() expects values to be hashable, so convert lists
# # to tuples.
# prepared_values = []
# for value in values:
# if hasattr(value, "resolve_expression"):
# prepared_values.append(value)
# else:
# prepared_values.append(tuple(value))
# return prepared_values
@ArrayField.register_lookup
class ArrayInLookup(In):
def get_prep_lookup(self):
values = super().get_prep_lookup()
if hasattr(values, "resolve_expression"):
return values
# In.process_rhs() expects values to be hashable, so convert lists
# to tuples.
prepared_values = []
for value in values:
if hasattr(value, "resolve_expression"):
prepared_values.append(value)
else:
prepared_values.append(tuple(value))
return prepared_values


class IndexTransform(Transform):
Expand Down Expand Up @@ -388,6 +350,5 @@ def __init__(self, start, end):
self.start = start
self.end = end


# def __call__(self, *args, **kwargs):
# return SliceTransform(self.start, self.end, *args, **kwargs)
def __call__(self, *args, **kwargs):
return SliceTransform(self.start, self.end, *args, **kwargs)
1 change: 1 addition & 0 deletions django_mongodb/forms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .array import * # NOQA: F403
Loading

0 comments on commit 583fd97

Please sign in to comment.