Skip to content

Commit

Permalink
isnull filter for relationships and attributes on backend (#3717)
Browse files Browse the repository at this point in the history
* isnull filter for relationships and attributes on backend

* add graphql isnull filters

* restrict isnull filter

* account for self.filters is None
  • Loading branch information
ajtmccarty authored Jun 27, 2024
1 parent 57d2716 commit 1a85a35
Show file tree
Hide file tree
Showing 10 changed files with 352 additions and 17 deletions.
4 changes: 3 additions & 1 deletion backend/infrahub/core/query/attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ async def default_attribute_query_filter( # pylint: disable=unused-argument,too
query_filter.append(QueryNode(name="i", labels=["Attribute"], params={"name": f"${param_prefix}_name"}))
query_params[f"{param_prefix}_name"] = name

if filter_name in ("value", "binary_address", "prefixlen"):
if filter_name in ("value", "binary_address", "prefixlen", "isnull"):
query_filter.append(QueryRel(labels=[RELATIONSHIP_TO_VALUE_LABEL]))

if filter_value is None:
Expand All @@ -232,6 +232,8 @@ async def default_attribute_query_filter( # pylint: disable=unused-argument,too
query_where.append(
f"toLower(toString(av.{filter_name})) CONTAINS toLower(toString(${param_prefix}_{filter_name}))"
)
elif filter_name == "isnull":
query_filter.append(QueryNode(name="av", labels=["AttributeValue"]))
elif support_profiles:
query_filter.append(QueryNode(name="av", labels=["AttributeValue"]))
query_where.append(f"(av.{filter_name} = ${param_prefix}_{filter_name} OR av.is_default)")
Expand Down
38 changes: 31 additions & 7 deletions backend/infrahub/core/query/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ class FieldAttributeRequirement:

@property
def supports_profile(self) -> bool:
return bool(self.field and self.field.is_attribute and self.field_attr_name in ("value", "values"))
return bool(self.field and self.field.is_attribute and self.field_attr_name in ("value", "values", "isnull"))

@property
def is_filter(self) -> bool:
Expand Down Expand Up @@ -701,6 +701,20 @@ def profile_final_value_query_variable(self) -> str:
def final_value_query_variable(self) -> str:
return f"attr{self.index}_final_value"

@property
def comparison_operator(self) -> str:
if self.field_attr_name == "isnull":
return "=" if self.field_attr_value is True else "<>"
if self.field_attr_name == "values":
return "IN"
return "="

@property
def field_attr_comparison_value(self) -> Any:
if self.field_attr_name == "isnull":
return "NULL"
return self.field_attr_value


class NodeGetListQuery(Query):
name = "node_get_list"
Expand All @@ -712,9 +726,23 @@ def __init__(
self.filters = filters
self.partial_match = partial_match
self._variables_to_track = ["n", "rb"]
self._validate_filters()

super().__init__(**kwargs)

def _validate_filters(self) -> None:
if not self.filters:
return
filter_errors = []
for filter_str in self.filters:
split_filter = filter_str.split("__")
if len(split_filter) > 2 and split_filter[-1] == "isnull":
filter_errors.append(
f"{filter_str} is not allowed: 'isnull' is not supported for attributes of relationships"
)
if filter_errors:
raise RuntimeError(*filter_errors)

def _track_variable(self, variable: str) -> None:
if variable not in self._variables_to_track:
self._variables_to_track.append(variable)
Expand Down Expand Up @@ -1016,18 +1044,14 @@ def _add_final_filter(self, field_attribute_requirements: list[FieldAttributeReq
if not far.is_filter or not far.supports_profile:
continue
var_name = f"final_attr_value{far.index}"
self.params[var_name] = far.field_attr_value
self.params[var_name] = far.field_attr_comparison_value
if self.partial_match:
where_parts.append(
f"toLower(toString({far.final_value_query_variable})) CONTAINS toLower(toString(${var_name}))"
)
continue
if far.field_attr_name == "values":
operator = "IN"
else:
operator = "="

where_parts.append(f"{far.final_value_query_variable} {operator} ${var_name}")
where_parts.append(f"{far.final_value_query_variable} {far.comparison_operator} ${var_name}")
if where_parts:
where_str = "WHERE " + " AND ".join(where_parts)
self.add_to_query(where_str)
Expand Down
20 changes: 16 additions & 4 deletions backend/infrahub/core/query/subquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ async def build_subquery_filter(
support_profiles: bool = False,
extra_tail_properties: Optional[dict[str, str]] = None,
) -> tuple[str, dict[str, Any], str]:
support_profiles = support_profiles and field and field.is_attribute and filter_name in ("value", "values")
support_profiles = (
support_profiles and field and field.is_attribute and filter_name in ("value", "values", "isnull")
)
params = {}
prefix = f"{result_prefix}{subquery_idx}"

Expand Down Expand Up @@ -62,13 +64,23 @@ async def build_subquery_filter(
to_return = f"{node_alias} AS {prefix}"
with_extra = ""
final_with_extra = ""
if extra_tail_properties:
is_isnull = filter_name == "isnull"
if extra_tail_properties or is_isnull:
tail_node = field_filter[-1]
with_extra += f", {tail_node.name}"
final_with_extra += f", latest_node_details[2] AS {tail_node.name}"
if extra_tail_properties:
for variable_name, tail_property in extra_tail_properties.items():
to_return += f", {tail_node.name}.{tail_property} AS {variable_name}"
match = "OPTIONAL MATCH" if optional_match else "MATCH"
match = "MATCH"
if optional_match or is_isnull:
match = "OPTIONAL MATCH"
is_active_filter = "latest_node_details[0] = TRUE"
if is_isnull and filter_value is True:
if field is not None and field.is_relationship:
is_active_filter = "latest_node_details[0] = FALSE OR latest_node_details[0] IS NULL"
elif field is not None and field.is_attribute:
is_active_filter = "(latest_node_details[2]).value = 'NULL'"
query = f"""
WITH {node_alias}
{match} path = {filter_str}
Expand All @@ -81,7 +93,7 @@ async def build_subquery_filter(
all(r IN relationships(path) WHERE r.status = "active") AS is_active{with_extra}
ORDER BY branch_level DESC, froms[-1] DESC, froms[-2] DESC
WITH head(collect([is_active, {node_alias}{with_extra}])) AS latest_node_details
WHERE latest_node_details[0] = TRUE
WHERE {is_active_filter}
WITH latest_node_details[1] AS {node_alias}{final_with_extra}
RETURN {to_return}
"""
Expand Down
4 changes: 2 additions & 2 deletions backend/infrahub/core/schema/relationship_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ async def get_query_filter(

return query_filter, query_params, query_where

if filter_name == "ids":
if filter_name in ("ids", "isnull"):
query_filter.extend(
[
QueryRel(name="r1", labels=[rel_type], direction=rels_direction["r1"]),
Expand All @@ -125,7 +125,7 @@ async def get_query_filter(
]
)

if filter_value is not None:
if filter_name == "ids" and filter_value is not None:
query_params[f"{prefix}_peer_ids"] = filter_value
query_where.append(f"peer.uuid IN ${prefix}_peer_ids")

Expand Down
4 changes: 3 additions & 1 deletion backend/infrahub/graphql/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,12 +783,14 @@ def generate_filters(
default_filters: list[str] = list(filters.keys())

filters["ids"] = graphene.List(graphene.ID)
if not top_level:
filters["isnull"] = graphene.Boolean()

for attr in schema.attributes:
attr_kind = get_attr_kind(node_schema=schema, attr_schema=attr)
filters.update(
get_attribute_type(kind=attr_kind).get_graphql_filters(
name=attr.name, include_properties=include_properties
name=attr.name, include_properties=include_properties, include_isnull=top_level
)
)

Expand Down
4 changes: 3 additions & 1 deletion backend/infrahub/graphql/types/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ async def get_paginated_list(cls, fields: dict, context: GraphqlContext, **kwarg
response: dict[str, Any] = {"edges": []}
offset = kwargs.pop("offset", None)
limit = kwargs.pop("limit", None)
filters = {key: value for key, value in kwargs.items() if ("__" in key and value) or key == "ids"}
filters = {
key: value for key, value in kwargs.items() if ("__" in key and value is not None) or key == "ids"
}
if "count" in fields:
response["count"] = await NodeManager.count(
db=db,
Expand Down
6 changes: 5 additions & 1 deletion backend/infrahub/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,15 @@ def get_graphql_type_name(cls) -> str:
return cls.get_graphql_type().__name__

@classmethod
def get_graphql_filters(cls, name: str, include_properties: bool = True) -> dict[str, typing.Any]:
def get_graphql_filters(
cls, name: str, include_properties: bool = True, include_isnull: bool = False
) -> dict[str, typing.Any]:
filters: dict[str, typing.Any] = {}
attr_class = cls.get_infrahub_class()
filters[f"{name}__value"] = cls.graphql_filter()
filters[f"{name}__values"] = graphene.List(cls.graphql_filter)
if include_isnull:
filters[f"{name}__isnull"] = graphene.Boolean()

if not include_properties:
return filters
Expand Down
151 changes: 151 additions & 0 deletions backend/tests/unit/core/test_node_get_list_query.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from random import randint

import pytest

from infrahub.core.branch import Branch
from infrahub.core.constants import (
BranchSupportType,
Expand Down Expand Up @@ -49,6 +51,155 @@ async def test_query_NodeGetListQuery_filter_ids(
assert query.get_node_ids() == [person_albert_main.id, person_jim_main.id, person_john_main.id]


async def test_query_NodeGetListQuery_filter_attribute_isnull(
db: InfrahubDatabase, person_albert_main, person_alfred_main, person_jane_main, branch: Branch
):
person_schema = registry.schema.get(name="TestPerson", branch=branch, duplicate=False)
person_branch = await NodeManager.get_one(db=db, branch=branch, id=person_albert_main.id)
person_branch.height.value = None
await person_branch.save(db=db)

query = await NodeGetListQuery.init(
db=db,
branch=branch,
schema=person_schema,
filters={"height__isnull": True},
)
await query.execute(db=db)
assert query.get_node_ids() == [person_albert_main.id]

query = await NodeGetListQuery.init(
db=db,
branch=branch,
schema=person_schema,
filters={"height__isnull": False},
)
await query.execute(db=db)
assert set(query.get_node_ids()) == {person_alfred_main.id, person_jane_main.id}

person_branch = await NodeManager.get_one(db=db, branch=branch, id=person_albert_main.id)
person_branch.height.value = 155
await person_branch.save(db=db)

query = await NodeGetListQuery.init(
db=db,
branch=branch,
schema=person_schema,
filters={"height__isnull": True},
)
await query.execute(db=db)
assert query.get_node_ids() == []

query = await NodeGetListQuery.init(
db=db,
branch=branch,
schema=person_schema,
filters={"height__isnull": False},
)
await query.execute(db=db)
assert set(query.get_node_ids()) == {person_albert_main.id, person_alfred_main.id, person_jane_main.id}


async def test_query_NodeGetListQuery_filter_relationship_isnull_one(
db: InfrahubDatabase, car_accord_main, car_camry_main, car_volt_main, person_jane_main, branch: Branch
):
car_schema = registry.schema.get(name="TestCar", branch=branch, duplicate=False)
owner_rel = car_schema.get_relationship(name="owner")
owner_rel.optional = True
car_branch = await NodeManager.get_one(db=db, branch=branch, id=car_camry_main.id)
await car_branch.owner.update(db=db, data=[None])
await car_branch.save(db=db)

query = await NodeGetListQuery.init(
db=db,
branch=branch,
schema=car_schema,
filters={"owner__isnull": True},
)
await query.execute(db=db)
assert query.get_node_ids() == [car_camry_main.id]

query = await NodeGetListQuery.init(
db=db,
branch=branch,
schema=car_schema,
filters={"owner__isnull": False},
)
await query.execute(db=db)
assert set(query.get_node_ids()) == {car_accord_main.id, car_volt_main.id}

car_branch = await NodeManager.get_one(db=db, branch=branch, id=car_camry_main.id)
await car_branch.owner.update(db=db, data=person_jane_main)
await car_branch.save(db=db)

query = await NodeGetListQuery.init(
db=db,
branch=branch,
schema=car_schema,
filters={"owner__isnull": True},
)
await query.execute(db=db)
assert query.get_node_ids() == []

query = await NodeGetListQuery.init(
db=db,
branch=branch,
schema=car_schema,
filters={"owner__isnull": False},
)
await query.execute(db=db)
assert set(query.get_node_ids()) == {car_camry_main.id, car_accord_main.id, car_volt_main.id}


async def test_query_NodeGetListQuery_filter_relationship_isnull_many(
db: InfrahubDatabase,
car_accord_main,
car_camry_main,
person_albert_main,
person_alfred_main,
person_jane_main,
person_john_main,
branch: Branch,
):
person_schema = registry.schema.get(name="TestPerson", branch=branch)
person_schema.order_by = ["name__value"]
car_branch = await NodeManager.get_one(db=db, branch=branch, id=car_camry_main.id)
await car_branch.owner.update(db=db, data=person_albert_main)
await car_branch.save(db=db)

query = await NodeGetListQuery.init(
db=db,
branch=branch,
schema=person_schema,
filters={"cars__isnull": True},
)
await query.execute(db=db)
assert query.get_node_ids() == [person_alfred_main.id, person_jane_main.id]

query = await NodeGetListQuery.init(
db=db,
branch=branch,
schema=person_schema,
filters={"cars__isnull": False},
)
await query.execute(db=db)
assert query.get_node_ids() == [person_albert_main.id, person_john_main.id]


async def test_query_NodeGetListQuery_filter_relationship_attribute_isnull_not_allowed(
db: InfrahubDatabase, car_person_schema, default_branch
):
car_schema = registry.schema.get(name="TestCar", branch=default_branch, duplicate=False)

with pytest.raises(RuntimeError, match=r"owner__height__isnull is not allowed"):
await NodeGetListQuery.init(
db=db,
branch=default_branch,
schema=car_schema,
filters={"owner__height__isnull": True},
)


async def test_query_NodeGetListQuery_filter_height(
db: InfrahubDatabase, person_john_main, person_jim_main, person_albert_main, person_alfred_main, branch: Branch
):
Expand Down
Loading

0 comments on commit 1a85a35

Please sign in to comment.