Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix typing for graphql.resolver #4551

Draft
wants to merge 1 commit into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions backend/infrahub/graphql/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from infrahub_sdk.utils import deep_merge_dict

if TYPE_CHECKING:
from infrahub.core.schema import NodeSchema
from infrahub.core.schema import MainSchemaTypes


@dataclass
Expand All @@ -26,13 +26,13 @@ class FieldEnricher:
fields: dict = field(default_factory=dict)


async def extract_selection(field_node: FieldNode, schema: NodeSchema) -> dict:
async def extract_selection(field_node: FieldNode, schema: MainSchemaTypes) -> dict:
graphql_extractor = GraphQLExtractor(field_node=field_node, schema=schema)
return await graphql_extractor.get_fields()


class GraphQLExtractor:
def __init__(self, field_node: FieldNode, schema: NodeSchema) -> None:
def __init__(self, field_node: FieldNode, schema: MainSchemaTypes) -> None:
self.field_node = field_node
self.schema = schema
self.typename_paths: dict[str, list[FieldEnricher]] = {}
Expand Down
114 changes: 81 additions & 33 deletions backend/infrahub/graphql/resolver.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,32 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Optional, TypeVar

from infrahub_sdk.utils import extract_fields
from infrahub_sdk.utils import extract_fields, extract_fields_first_node

from infrahub.core.constants import BranchSupportType, InfrahubKind, RelationshipHierarchyDirection
from infrahub.core.manager import NodeManager
from infrahub.core.query.node import NodeGetHierarchyQuery
from infrahub.core.schema import NodeSchema
from infrahub.exceptions import NodeNotFoundError

from .parser import extract_selection
from .permissions import get_permissions
from .types import RELATIONS_PROPERTY_MAP, RELATIONS_PROPERTY_MAP_REVERSED

SchemaType = TypeVar("SchemaType")

if TYPE_CHECKING:
from graphql import GraphQLResolveInfo
from graphql import GraphQLObjectType, GraphQLOutputType, GraphQLResolveInfo

from infrahub.core.schema import MainSchemaTypes, NodeSchema
from infrahub.core.schema import MainSchemaTypes
from infrahub.graphql.initialization import GraphqlContext
from infrahub.graphql.types import InfrahubObject
from infrahub.graphql.types.node import InfrahubObjectOptions


async def account_resolver(
root, # pylint: disable=unused-argument
root: Any, # pylint: disable=unused-argument
info: GraphQLResolveInfo,
) -> dict:
fields = await extract_fields(info.field_nodes[0].selection_set)
Expand All @@ -30,18 +35,20 @@ async def account_resolver(
async with context.db.start_session() as db:
results = await NodeManager.query(
schema=InfrahubKind.GENERICACCOUNT,
filters={"ids": [context.account_session.account_id]},
filters={"ids": [context.active_account_session.account_id]},
fields=fields,
db=db,
)
if results:
account_profile = await results[0].to_graphql(db=db, fields=fields)
return account_profile

raise NodeNotFoundError(node_type=InfrahubKind.GENERICACCOUNT, identifier=context.account_session.account_id)
raise NodeNotFoundError(
node_type=InfrahubKind.GENERICACCOUNT, identifier=context.active_account_session.account_id
)


async def default_resolver(*args: Any, **kwargs) -> dict | list[dict] | None:
async def default_resolver(*args: Any, **kwargs: dict[str, Any]) -> dict | list[dict] | None:
"""Not sure why but the default resolver returns sometime 4 positional args and sometime 2.

When it returns 4, they are organized as follow
Expand Down Expand Up @@ -126,7 +133,7 @@ async def default_paginated_list_resolver(
partial_match: bool = False,
**kwargs: dict[str, Any],
) -> dict[str, Any]:
schema: MainSchemaTypes = info.return_type.graphene_type._meta.schema
schema = _return_object_type_schema(object_type=info.return_type)
fields = await extract_selection(info.field_nodes[0], schema=schema)

context: GraphqlContext = info.context
Expand Down Expand Up @@ -183,20 +190,22 @@ async def default_paginated_list_resolver(
return response


async def single_relationship_resolver(parent: dict, info: GraphQLResolveInfo, **kwargs) -> dict[str, Any]:
async def single_relationship_resolver(
parent: dict, info: GraphQLResolveInfo, **kwargs: dict[str, Any]
) -> dict[str, Any]:
"""Resolver for relationships of cardinality=one for Edged responses

This resolver is used for paginated responses and as such we redefined the requested
fields by only reusing information below the 'node' key.
"""
# Extract the InfraHub schema by inspecting the GQL Schema

node_schema: NodeSchema = info.parent_type.graphene_type._meta.schema
node_schema = _return_requested_object_type_schema(object_type=info.parent_type, schema_type=NodeSchema)

context: GraphqlContext = info.context

# Extract the name of the fields in the GQL query
fields = await extract_fields(info.field_nodes[0].selection_set)
fields = await extract_fields_first_node(info)
node_fields = fields.get("node", {})
property_fields = fields.get("properties", {})
for key, value in property_fields.items():
Expand Down Expand Up @@ -242,20 +251,25 @@ async def single_relationship_resolver(parent: dict, info: GraphQLResolveInfo, *


async def many_relationship_resolver(
parent: dict, info: GraphQLResolveInfo, include_descendants: Optional[bool] = False, **kwargs
parent: dict,
info: GraphQLResolveInfo,
include_descendants: Optional[bool] = False,
limit: int | None = None,
offset: int | None = None,
**kwargs: dict[str, Any],
) -> dict[str, Any]:
"""Resolver for relationships of cardinality=many for Edged responses

This resolver is used for paginated responses and as such we redefined the requested
fields by only reusing information below the 'node' key.
"""
# Extract the InfraHub schema by inspecting the GQL Schema
node_schema: NodeSchema = info.parent_type.graphene_type._meta.schema
node_schema = _return_requested_object_type_schema(object_type=info.parent_type, schema_type=NodeSchema)

context: GraphqlContext = info.context

# Extract the name of the fields in the GQL query
fields = await extract_fields(info.field_nodes[0].selection_set)
fields = await extract_fields_first_node(info)
edges = fields.get("edges", {})
node_fields = edges.get("node", {})
property_fields = edges.get("properties", {})
Expand All @@ -266,10 +280,6 @@ async def many_relationship_resolver(
# Extract the schema of the node on the other end of the relationship from the GQL Schema
node_rel = node_schema.get_relationship(info.field_name)

# Extract only the filters from the kwargs and prepend the name of the field to the filters
offset = kwargs.pop("offset", None)
limit = kwargs.pop("limit", None)

filters = {
f"{info.field_name}__{key}": value
for key, value in kwargs.items()
Expand Down Expand Up @@ -335,7 +345,7 @@ async def many_relationship_resolver(

entries = []
for node in node_graph:
entry = {"node": {}, "properties": {}}
entry: dict[str, dict] = {"node": {}, "properties": {}}
for key, mapped in RELATIONS_PROPERTY_MAP_REVERSED.items():
value = node.pop(key, None)
if value:
Expand All @@ -347,40 +357,62 @@ async def many_relationship_resolver(
return response


async def ancestors_resolver(parent: dict, info: GraphQLResolveInfo, **kwargs) -> dict[str, Any]:
async def ancestors_resolver(
parent: dict,
info: GraphQLResolveInfo,
limit: int | None = None,
offset: int | None = None,
**kwargs: dict[str, Any],
) -> dict[str, Any]:
return await hierarchy_resolver(
direction=RelationshipHierarchyDirection.ANCESTORS, parent=parent, info=info, **kwargs
direction=RelationshipHierarchyDirection.ANCESTORS,
parent=parent,
info=info,
limit=limit,
offset=offset,
**kwargs,
)


async def descendants_resolver(parent: dict, info: GraphQLResolveInfo, **kwargs) -> dict[str, Any]:
async def descendants_resolver(
parent: dict,
info: GraphQLResolveInfo,
limit: int | None = None,
offset: int | None = None,
**kwargs: dict[str, Any],
) -> dict[str, Any]:
return await hierarchy_resolver(
direction=RelationshipHierarchyDirection.DESCENDANTS, parent=parent, info=info, **kwargs
direction=RelationshipHierarchyDirection.DESCENDANTS,
parent=parent,
info=info,
limit=limit,
offset=offset,
**kwargs,
)


async def hierarchy_resolver(
direction: RelationshipHierarchyDirection, parent: dict, info: GraphQLResolveInfo, **kwargs
direction: RelationshipHierarchyDirection,
parent: dict,
info: GraphQLResolveInfo,
limit: int | None = None,
offset: int | None = None,
**kwargs: dict[str, Any],
) -> dict[str, Any]:
"""Resolver for ancestors and dependants for Hierarchical nodes

This resolver is used for paginated responses and as such we redefined the requested
fields by only reusing information below the 'node' key.
"""
# Extract the InfraHub schema by inspecting the GQL Schema
node_schema: NodeSchema = info.parent_type.graphene_type._meta.schema
node_schema = _return_requested_object_type_schema(object_type=info.parent_type, schema_type=NodeSchema)

context: GraphqlContext = info.context

# Extract the name of the fields in the GQL query
fields = await extract_fields(info.field_nodes[0].selection_set)
fields = await extract_fields_first_node(info)
edges = fields.get("edges", {})
node_fields = edges.get("node", {})

# Extract only the filters from the kwargs and prepend the name of the field to the filters
offset = kwargs.pop("offset", None)
limit = kwargs.pop("limit", None)

filters = {
f"{info.field_name}__{key}": value
for key, value in kwargs.items()
Expand Down Expand Up @@ -423,9 +455,25 @@ async def hierarchy_resolver(

entries = []
for node in node_graph:
entry = {"node": {}, "properties": {}}
entry: dict[str, dict] = {"node": {}, "properties": {}}
entry["node"] = node
entries.append(entry)
response["edges"] = entries

return response


def _return_object_type_schema(object_type: GraphQLObjectType | GraphQLOutputType) -> MainSchemaTypes:
infrahub_object: InfrahubObject = getattr(object_type, "graphene_type")
object_options: InfrahubObjectOptions = getattr(infrahub_object, "_meta")
return object_options.schema


def _return_requested_object_type_schema(
object_type: GraphQLObjectType | GraphQLOutputType, schema_type: type[SchemaType]
) -> SchemaType:
schema = _return_object_type_schema(object_type=object_type)
if isinstance(schema, schema_type):
return schema

raise TypeError("The object doesn't match the requested schema")
4 changes: 0 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -327,10 +327,6 @@ ignore_errors = true
module = "infrahub.graphql.mutations.schema"
ignore_errors = true

[[tool.mypy.overrides]]
module = "infrahub.graphql.resolver"
ignore_errors = true

[[tool.mypy.overrides]]
module = "infrahub.graphql.subscription"
ignore_errors = true
Expand Down
Loading