From 526eb82b70451c0e59d5a71ae9b7396f59974bd8 Mon Sep 17 00:00:00 2001 From: Thiago Bellini Ribeiro Date: Thu, 9 Jan 2025 19:29:13 +0100 Subject: [PATCH] fix: Prevent a possible security issue when resolving a relay node with multiple possibilities (#3749) --- RELEASE.md | 21 +++++ strawberry/__init__.py | 2 + .../experimental/pydantic/object_type.py | 4 + strawberry/relay/fields.py | 31 +++++++- strawberry/schema/schema_converter.py | 7 ++ strawberry/types/cast.py | 35 +++++++++ tests/relay/test_fields.py | 78 +++++++++++++++++++ tests/types/test_cast.py | 28 +++++++ 8 files changed, 203 insertions(+), 3 deletions(-) create mode 100644 RELEASE.md create mode 100644 strawberry/types/cast.py create mode 100644 tests/types/test_cast.py diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..c28e496da1 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,21 @@ +Release type: minor + +The common `node: Node` used to resolve relay nodes means we will be relying on +is_type_of to check if the returned object is in fact a subclass of the Node +interface. + +However, integrations such as Django, SQLAlchemy and Pydantic will not return +the type itself, but instead an alike object that is later resolved to the +expected type. + +In case there are more than one possible type defined for that model that is +being returned, the first one that replies True to `is_type_of` check would be +used in the resolution, meaning that when asking for `"PublicUser:123"`, +strawberry could end up returning `"User:123"`, which can lead to security +issues (such as data leakage). + +In here we are introducing a new `strawberry.cast`, which will be used to mark +an object with the already known type by us, and when asking for is_type_of that +mark will be used to check instead, ensuring we will return the correct type. + +That `cast` is already in place for the relay node resolution and pydantic. diff --git a/strawberry/__init__.py b/strawberry/__init__.py index d03f15d998..c1365b374f 100644 --- a/strawberry/__init__.py +++ b/strawberry/__init__.py @@ -13,6 +13,7 @@ from .schema_directive import schema_directive from .types.arguments import argument from .types.auto import auto +from .types.cast import cast from .types.enum import enum, enum_value from .types.field import field from .types.info import Info @@ -36,6 +37,7 @@ "argument", "asdict", "auto", + "cast", "directive", "directive_field", "enum", diff --git a/strawberry/experimental/pydantic/object_type.py b/strawberry/experimental/pydantic/object_type.py index ac8958a4aa..42216b8d02 100644 --- a/strawberry/experimental/pydantic/object_type.py +++ b/strawberry/experimental/pydantic/object_type.py @@ -29,6 +29,7 @@ get_private_fields, ) from strawberry.types.auto import StrawberryAuto +from strawberry.types.cast import get_strawberry_type_cast from strawberry.types.field import StrawberryField from strawberry.types.object_type import _process_type, _wrap_dataclass from strawberry.types.type_resolver import _get_fields @@ -207,6 +208,9 @@ def wrap(cls: Any) -> builtins.type[StrawberryTypeFromPydantic[PydanticModel]]: # pydantic objects (not the corresponding strawberry type) @classmethod # type: ignore def is_type_of(cls: builtins.type, obj: Any, _info: GraphQLResolveInfo) -> bool: + if (type_cast := get_strawberry_type_cast(obj)) is not None: + return type_cast is cls + return isinstance(obj, (cls, model)) namespace = {"is_type_of": is_type_of} diff --git a/strawberry/relay/fields.py b/strawberry/relay/fields.py index 347fd22169..3e0c77e240 100644 --- a/strawberry/relay/fields.py +++ b/strawberry/relay/fields.py @@ -37,6 +37,7 @@ ) from strawberry.types.arguments import StrawberryArgument, argument from strawberry.types.base import StrawberryList, StrawberryOptional +from strawberry.types.cast import cast as strawberry_cast from strawberry.types.field import _RESOLVER_TYPE, StrawberryField, field from strawberry.types.fields.resolver import StrawberryResolver from strawberry.types.lazy_type import LazyType @@ -88,12 +89,27 @@ def resolver( info: Info, id: Annotated[GlobalID, argument(description="The ID of the object.")], ) -> Union[Node, None, Awaitable[Union[Node, None]]]: - return id.resolve_type(info).resolve_node( + node_type = id.resolve_type(info) + resolved_node = node_type.resolve_node( id.node_id, info=info, required=not is_optional, ) + # We are using `strawberry_cast` here to cast the resolved node to make + # sure `is_type_of` will not try to find its type again. Very important + # when returning a non type (e.g. Django/SQLAlchemy/Pydantic model), as + # we could end up resolving to a different type in case more than one + # are registered. + if inspect.isawaitable(resolved_node): + + async def resolve() -> Any: + return strawberry_cast(node_type, await resolved_node) + + return resolve() + + return cast(Node, strawberry_cast(node_type, resolved_node)) + return resolver def get_node_list_resolver( @@ -139,6 +155,14 @@ def resolver( if inspect.isasyncgen(nodes) } + # We are using `strawberry_cast` here to cast the resolved node to make + # sure `is_type_of` will not try to find its type again. Very important + # when returning a non type (e.g. Django/SQLAlchemy/Pydantic model), as + # we could end up resolving to a different type in case more than one + # are registered + def cast_nodes(node_t: type[Node], nodes: Iterable[Any]) -> list[Node]: + return [cast(Node, strawberry_cast(node_t, node)) for node in nodes] + if awaitable_nodes or asyncgen_nodes: async def resolve(resolved: Any = resolved_nodes) -> list[Node]: @@ -161,7 +185,8 @@ async def resolve(resolved: Any = resolved_nodes) -> list[Node]: # Resolve any generator to lists resolved = { - node_t: list(nodes) for node_t, nodes in resolved.items() + node_t: cast_nodes(node_t, nodes) + for node_t, nodes in resolved.items() } return [ resolved[index_map[gid][0]][index_map[gid][1]] for gid in ids @@ -171,7 +196,7 @@ async def resolve(resolved: Any = resolved_nodes) -> list[Node]: # Resolve any generator to lists resolved = { - node_t: list(cast(Iterator[Node], nodes)) + node_t: cast_nodes(node_t, cast(Iterable[Node], nodes)) for node_t, nodes in resolved_nodes.items() } return [resolved[index_map[gid][0]][index_map[gid][1]] for gid in ids] diff --git a/strawberry/schema/schema_converter.py b/strawberry/schema/schema_converter.py index 2648aa6ad9..000d384934 100644 --- a/strawberry/schema/schema_converter.py +++ b/strawberry/schema/schema_converter.py @@ -58,6 +58,7 @@ get_object_definition, has_object_definition, ) +from strawberry.types.cast import get_strawberry_type_cast from strawberry.types.enum import EnumDefinition from strawberry.types.field import UNRESOLVED from strawberry.types.lazy_type import LazyType @@ -619,6 +620,9 @@ def _get_is_type_of() -> Optional[Callable[[Any, GraphQLResolveInfo], bool]]: ) def is_type_of(obj: Any, _info: GraphQLResolveInfo) -> bool: + if (type_cast := get_strawberry_type_cast(obj)) is not None: + return type_cast in possible_types + if object_type.concrete_of and ( has_object_definition(obj) and obj.__strawberry_definition__.origin @@ -898,6 +902,9 @@ def _get_is_type_of( if object_type.interfaces: def is_type_of(obj: Any, _info: GraphQLResolveInfo) -> bool: + if (type_cast := get_strawberry_type_cast(obj)) is not None: + return type_cast is object_type.origin + if object_type.concrete_of and ( has_object_definition(obj) and obj.__strawberry_definition__.origin diff --git a/strawberry/types/cast.py b/strawberry/types/cast.py new file mode 100644 index 0000000000..0cf903beaf --- /dev/null +++ b/strawberry/types/cast.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from typing import Any, TypeVar, overload + +_T = TypeVar("_T", bound=object) + +TYPE_CAST_ATTRIBUTE = "__as_strawberry_type__" + + +@overload +def cast(type_: type, obj: None) -> None: ... + + +@overload +def cast(type_: type, obj: _T) -> _T: ... + + +def cast(type_: type, obj: _T | None) -> _T | None: + """Cast an object to given type. + + This is used to mark an object as a cast object, so that the type can be + picked up when resolving unions/interfaces in case of ambiguity, which can + happen when returning an alike object instead of an instance of the type + (e.g. returning a Django, Pydantic or SQLAlchemy object) + """ + if obj is None: + return None + + setattr(obj, TYPE_CAST_ATTRIBUTE, type_) + return obj + + +def get_strawberry_type_cast(obj: Any) -> type | None: + """Get the type of a cast object.""" + return getattr(obj, TYPE_CAST_ATTRIBUTE, None) diff --git a/tests/relay/test_fields.py b/tests/relay/test_fields.py index 2ef08c3713..957e4b0795 100644 --- a/tests/relay/test_fields.py +++ b/tests/relay/test_fields.py @@ -1,4 +1,8 @@ +import dataclasses import textwrap +from collections.abc import Iterable +from typing import Optional, Union +from typing_extensions import Self import pytest from pytest_mock import MockerFixture @@ -1621,3 +1625,77 @@ def test_query_after_error(): assert result.errors is not None assert "Argument 'after' contains a non-existing value" in str(result.errors) + + +@pytest.mark.parametrize( + ("type_name", "should_have_name"), + [("Fruit", False), ("PublicFruit", True)], +) +@pytest.mark.django_db(transaction=True) +def test_correct_model_returned(type_name: str, should_have_name: bool): + @dataclasses.dataclass + class FruitModel: + id: str + name: str + + fruits: dict[str, FruitModel] = {"1": FruitModel(id="1", name="Strawberry")} + + @strawberry.type + class Fruit(relay.Node): + id: relay.NodeID[int] + + @classmethod + def resolve_nodes( + cls, + *, + info: Optional[strawberry.Info] = None, + node_ids: Iterable[str], + required: bool = False, + ) -> Iterable[Optional[Union[Self, FruitModel]]]: + return [fruits[nid] if required else fruits.get(nid) for nid in node_ids] + + @strawberry.type + class PublicFruit(relay.Node): + id: relay.NodeID[int] + name: str + + @classmethod + def resolve_nodes( + cls, + *, + info: Optional[strawberry.Info] = None, + node_ids: Iterable[str], + required: bool = False, + ) -> Iterable[Optional[Union[Self, FruitModel]]]: + return [fruits[nid] if required else fruits.get(nid) for nid in node_ids] + + @strawberry.type + class Query: + node: relay.Node = relay.node() + + schema = strawberry.Schema(query=Query, types=[Fruit, PublicFruit]) + + node_id = relay.to_base64(type_name, "1") + result = schema.execute_sync( + """ + query NodeQuery($id: GlobalID!) { + node(id: $id) { + __typename + id + ... on PublicFruit { + name + } + } + } + """, + {"id": node_id}, + ) + assert result.errors is None + assert isinstance(result.data, dict) + + assert result.data["node"]["__typename"] == type_name + assert result.data["node"]["id"] == node_id + if should_have_name: + assert result.data["node"]["name"] == "Strawberry" + else: + assert "name" not in result.data["node"] diff --git a/tests/types/test_cast.py b/tests/types/test_cast.py new file mode 100644 index 0000000000..6721676cde --- /dev/null +++ b/tests/types/test_cast.py @@ -0,0 +1,28 @@ +import strawberry +from strawberry.types.cast import get_strawberry_type_cast + + +def test_cast(): + @strawberry.type + class SomeType: ... + + class OtherType: ... + + obj = OtherType + assert get_strawberry_type_cast(obj) is None + + cast_obj = strawberry.cast(SomeType, obj) + assert cast_obj is obj + assert get_strawberry_type_cast(cast_obj) is SomeType + + +def test_cast_none_obj(): + @strawberry.type + class SomeType: ... + + obj = None + assert get_strawberry_type_cast(obj) is None + + cast_obj = strawberry.cast(SomeType, obj) + assert cast_obj is None + assert get_strawberry_type_cast(obj) is None