Skip to content

Commit

Permalink
Merge pull request #94 from Ariana-B/sqlalchemy-invalid-field-options
Browse files Browse the repository at this point in the history
Enable custom handling of undefined field attr in to_filter
  • Loading branch information
constantinius authored Jul 8, 2024
2 parents d51dbb0 + f0c7e9f commit 23f172c
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 16 deletions.
19 changes: 10 additions & 9 deletions pygeofilter/backends/sqlalchemy/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@


class SQLAlchemyFilterEvaluator(Evaluator):
def __init__(self, field_mapping):
def __init__(self, field_mapping, undefined_as_null):
self.field_mapping = field_mapping
self.undefined_as_null = undefined_as_null

@handle(ast.Not)
def not_(self, node, sub):
Expand Down Expand Up @@ -105,7 +106,7 @@ def bbox(self, node, lhs):

@handle(ast.Attribute)
def attribute(self, node):
return filters.attribute(node.name, self.field_mapping)
return filters.attribute(node.name, self.field_mapping, self.undefined_as_null)

@handle(ast.Arithmetic, subclasses=True)
def arithmetic(self, node, lhs, rhs):
Expand Down Expand Up @@ -133,15 +134,15 @@ def envelope(self, node):
return filters.parse_bbox([node.x1, node.y1, node.x2, node.y2])


def to_filter(ast, field_mapping=None):
"""Helper function to translate ECQL AST to Django Query expressions.
def to_filter(ast, field_mapping={}, undefined_as_null=None):
"""Helper function to translate ECQL AST to SQLAlchemy Query expressions.
:param ast: the abstract syntax tree
:param field_mapping: a dict mapping from the filter name to the Django
:param field_mapping: a dict mapping from the filter name to the SQLAlchemy
field lookup.
:param mapping_choices: a dict mapping field lookups to choices.
:param undefined_as_null: whether a name not present in field_mapping
should evaluate to null.
:type ast: :class:`Node`
:returns: a Django query object
:rtype: :class:`django.db.models.Q`
:returns: a SQLAlchemy query object
"""
return SQLAlchemyFilterEvaluator(field_mapping).evaluate(ast)
return SQLAlchemyFilterEvaluator(field_mapping, undefined_as_null).evaluate(ast)
16 changes: 11 additions & 5 deletions pygeofilter/backends/sqlalchemy/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Callable, Dict, Optional

from pygeoif import shape
from sqlalchemy import and_, func, not_, or_
from sqlalchemy import and_, func, not_, or_, null


def parse_bbox(box, srid: Optional[int] = None):
Expand Down Expand Up @@ -257,15 +257,21 @@ def bbox(lhs, minx, miny, maxx, maxy, crs=4326):
return lhs.ST_Intersects(parse_bbox([minx, miny, maxx, maxy], crs))


def attribute(name, field_mapping=None):
def attribute(name, field_mapping={}, undefined_as_null: bool = None):
"""Create an attribute lookup expression using a field mapping dictionary.
:param name: the field filter name
:param field_mapping: the dictionary to use as a lookup.
:param undefined_as_null: how to handle a name not present in field_mapping
(None (default) - leave as-is; True - treat as null; False - throw error)
"""
field = field_mapping.get(name, name)

return field
if undefined_as_null is None:
return field_mapping.get(name, name)
if undefined_as_null:
# return null object if name is not found in field_mapping
return field_mapping.get(name, null())
# undefined_as_null is False, so raise KeyError if name not found
return field_mapping[name]


def literal(value):
Expand Down
22 changes: 20 additions & 2 deletions tests/backends/sqlalchemy/test_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,9 @@ def db_session(setup_database, connection):
transaction.rollback()


def evaluate(session, cql_expr, expected_ids):
def evaluate(session, cql_expr, expected_ids, filter_option=None):
ast = parse(cql_expr)
filters = to_filter(ast, FIELD_MAPPING)
filters = to_filter(ast, FIELD_MAPPING, filter_option)

q = session.query(Record).join(RecordMeta).filter(filters)
results = [row.identifier for row in q]
Expand Down Expand Up @@ -415,3 +415,21 @@ def test_arith_field_plus_mul_1(db_session):

def test_arith_field_plus_mul_2(db_session):
evaluate(db_session, "intMetaAttribute = 5 + intAttribute * 1.5", ("A",))


# handling undefined/invalid attributes


def test_undef_comp(db_session):
# treat undefined/invalid attribute as null
evaluate(db_session, "missingAttribute > 10", (), True)


def test_undef_isnull(db_session):
evaluate(db_session, "missingAttribute IS NULL", ("A", "B"), True)


def test_undef_comp_error(db_session):
# error if undefined/invalid attribute
with pytest.raises(KeyError):
evaluate(db_session, "missingAttribute > 10", (), False)

0 comments on commit 23f172c

Please sign in to comment.