Skip to content

Commit

Permalink
fix: Prevent a possible security issue when resolving a relay node wi…
Browse files Browse the repository at this point in the history
…th multiple possibilities (#3749)
  • Loading branch information
bellini666 authored Jan 9, 2025
1 parent fc854f1 commit 526eb82
Show file tree
Hide file tree
Showing 8 changed files with 203 additions and 3 deletions.
21 changes: 21 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
Release type: minor

The common `node: Node` used to resolve relay nodes means we will be relying on
is_type_of to check if the returned object is in fact a subclass of the Node
interface.

However, integrations such as Django, SQLAlchemy and Pydantic will not return
the type itself, but instead an alike object that is later resolved to the
expected type.

In case there are more than one possible type defined for that model that is
being returned, the first one that replies True to `is_type_of` check would be
used in the resolution, meaning that when asking for `"PublicUser:123"`,
strawberry could end up returning `"User:123"`, which can lead to security
issues (such as data leakage).

In here we are introducing a new `strawberry.cast`, which will be used to mark
an object with the already known type by us, and when asking for is_type_of that
mark will be used to check instead, ensuring we will return the correct type.

That `cast` is already in place for the relay node resolution and pydantic.
2 changes: 2 additions & 0 deletions strawberry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .schema_directive import schema_directive
from .types.arguments import argument
from .types.auto import auto
from .types.cast import cast
from .types.enum import enum, enum_value
from .types.field import field
from .types.info import Info
Expand All @@ -36,6 +37,7 @@
"argument",
"asdict",
"auto",
"cast",
"directive",
"directive_field",
"enum",
Expand Down
4 changes: 4 additions & 0 deletions strawberry/experimental/pydantic/object_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
get_private_fields,
)
from strawberry.types.auto import StrawberryAuto
from strawberry.types.cast import get_strawberry_type_cast
from strawberry.types.field import StrawberryField
from strawberry.types.object_type import _process_type, _wrap_dataclass
from strawberry.types.type_resolver import _get_fields
Expand Down Expand Up @@ -207,6 +208,9 @@ def wrap(cls: Any) -> builtins.type[StrawberryTypeFromPydantic[PydanticModel]]:
# pydantic objects (not the corresponding strawberry type)
@classmethod # type: ignore
def is_type_of(cls: builtins.type, obj: Any, _info: GraphQLResolveInfo) -> bool:
if (type_cast := get_strawberry_type_cast(obj)) is not None:
return type_cast is cls

return isinstance(obj, (cls, model))

namespace = {"is_type_of": is_type_of}
Expand Down
31 changes: 28 additions & 3 deletions strawberry/relay/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
)
from strawberry.types.arguments import StrawberryArgument, argument
from strawberry.types.base import StrawberryList, StrawberryOptional
from strawberry.types.cast import cast as strawberry_cast
from strawberry.types.field import _RESOLVER_TYPE, StrawberryField, field
from strawberry.types.fields.resolver import StrawberryResolver
from strawberry.types.lazy_type import LazyType
Expand Down Expand Up @@ -88,12 +89,27 @@ def resolver(
info: Info,
id: Annotated[GlobalID, argument(description="The ID of the object.")],
) -> Union[Node, None, Awaitable[Union[Node, None]]]:
return id.resolve_type(info).resolve_node(
node_type = id.resolve_type(info)
resolved_node = node_type.resolve_node(
id.node_id,
info=info,
required=not is_optional,
)

# We are using `strawberry_cast` here to cast the resolved node to make
# sure `is_type_of` will not try to find its type again. Very important
# when returning a non type (e.g. Django/SQLAlchemy/Pydantic model), as
# we could end up resolving to a different type in case more than one
# are registered.
if inspect.isawaitable(resolved_node):

async def resolve() -> Any:
return strawberry_cast(node_type, await resolved_node)

return resolve()

return cast(Node, strawberry_cast(node_type, resolved_node))

return resolver

def get_node_list_resolver(
Expand Down Expand Up @@ -139,6 +155,14 @@ def resolver(
if inspect.isasyncgen(nodes)
}

# We are using `strawberry_cast` here to cast the resolved node to make
# sure `is_type_of` will not try to find its type again. Very important
# when returning a non type (e.g. Django/SQLAlchemy/Pydantic model), as
# we could end up resolving to a different type in case more than one
# are registered
def cast_nodes(node_t: type[Node], nodes: Iterable[Any]) -> list[Node]:
return [cast(Node, strawberry_cast(node_t, node)) for node in nodes]

if awaitable_nodes or asyncgen_nodes:

async def resolve(resolved: Any = resolved_nodes) -> list[Node]:
Expand All @@ -161,7 +185,8 @@ async def resolve(resolved: Any = resolved_nodes) -> list[Node]:

# Resolve any generator to lists
resolved = {
node_t: list(nodes) for node_t, nodes in resolved.items()
node_t: cast_nodes(node_t, nodes)
for node_t, nodes in resolved.items()
}
return [
resolved[index_map[gid][0]][index_map[gid][1]] for gid in ids
Expand All @@ -171,7 +196,7 @@ async def resolve(resolved: Any = resolved_nodes) -> list[Node]:

# Resolve any generator to lists
resolved = {
node_t: list(cast(Iterator[Node], nodes))
node_t: cast_nodes(node_t, cast(Iterable[Node], nodes))
for node_t, nodes in resolved_nodes.items()
}
return [resolved[index_map[gid][0]][index_map[gid][1]] for gid in ids]
Expand Down
7 changes: 7 additions & 0 deletions strawberry/schema/schema_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
get_object_definition,
has_object_definition,
)
from strawberry.types.cast import get_strawberry_type_cast
from strawberry.types.enum import EnumDefinition
from strawberry.types.field import UNRESOLVED
from strawberry.types.lazy_type import LazyType
Expand Down Expand Up @@ -619,6 +620,9 @@ def _get_is_type_of() -> Optional[Callable[[Any, GraphQLResolveInfo], bool]]:
)

def is_type_of(obj: Any, _info: GraphQLResolveInfo) -> bool:
if (type_cast := get_strawberry_type_cast(obj)) is not None:
return type_cast in possible_types

if object_type.concrete_of and (
has_object_definition(obj)
and obj.__strawberry_definition__.origin
Expand Down Expand Up @@ -898,6 +902,9 @@ def _get_is_type_of(
if object_type.interfaces:

def is_type_of(obj: Any, _info: GraphQLResolveInfo) -> bool:
if (type_cast := get_strawberry_type_cast(obj)) is not None:
return type_cast is object_type.origin

if object_type.concrete_of and (
has_object_definition(obj)
and obj.__strawberry_definition__.origin
Expand Down
35 changes: 35 additions & 0 deletions strawberry/types/cast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from __future__ import annotations

from typing import Any, TypeVar, overload

_T = TypeVar("_T", bound=object)

TYPE_CAST_ATTRIBUTE = "__as_strawberry_type__"


@overload
def cast(type_: type, obj: None) -> None: ...


@overload
def cast(type_: type, obj: _T) -> _T: ...


def cast(type_: type, obj: _T | None) -> _T | None:
"""Cast an object to given type.
This is used to mark an object as a cast object, so that the type can be
picked up when resolving unions/interfaces in case of ambiguity, which can
happen when returning an alike object instead of an instance of the type
(e.g. returning a Django, Pydantic or SQLAlchemy object)
"""
if obj is None:
return None

setattr(obj, TYPE_CAST_ATTRIBUTE, type_)
return obj


def get_strawberry_type_cast(obj: Any) -> type | None:
"""Get the type of a cast object."""
return getattr(obj, TYPE_CAST_ATTRIBUTE, None)
78 changes: 78 additions & 0 deletions tests/relay/test_fields.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import dataclasses
import textwrap
from collections.abc import Iterable
from typing import Optional, Union
from typing_extensions import Self

import pytest
from pytest_mock import MockerFixture
Expand Down Expand Up @@ -1621,3 +1625,77 @@ def test_query_after_error():

assert result.errors is not None
assert "Argument 'after' contains a non-existing value" in str(result.errors)


@pytest.mark.parametrize(
("type_name", "should_have_name"),
[("Fruit", False), ("PublicFruit", True)],
)
@pytest.mark.django_db(transaction=True)
def test_correct_model_returned(type_name: str, should_have_name: bool):
@dataclasses.dataclass
class FruitModel:
id: str
name: str

fruits: dict[str, FruitModel] = {"1": FruitModel(id="1", name="Strawberry")}

@strawberry.type
class Fruit(relay.Node):
id: relay.NodeID[int]

@classmethod
def resolve_nodes(
cls,
*,
info: Optional[strawberry.Info] = None,
node_ids: Iterable[str],
required: bool = False,
) -> Iterable[Optional[Union[Self, FruitModel]]]:
return [fruits[nid] if required else fruits.get(nid) for nid in node_ids]

@strawberry.type
class PublicFruit(relay.Node):
id: relay.NodeID[int]
name: str

@classmethod
def resolve_nodes(
cls,
*,
info: Optional[strawberry.Info] = None,
node_ids: Iterable[str],
required: bool = False,
) -> Iterable[Optional[Union[Self, FruitModel]]]:
return [fruits[nid] if required else fruits.get(nid) for nid in node_ids]

@strawberry.type
class Query:
node: relay.Node = relay.node()

schema = strawberry.Schema(query=Query, types=[Fruit, PublicFruit])

node_id = relay.to_base64(type_name, "1")
result = schema.execute_sync(
"""
query NodeQuery($id: GlobalID!) {
node(id: $id) {
__typename
id
... on PublicFruit {
name
}
}
}
""",
{"id": node_id},
)
assert result.errors is None
assert isinstance(result.data, dict)

assert result.data["node"]["__typename"] == type_name
assert result.data["node"]["id"] == node_id
if should_have_name:
assert result.data["node"]["name"] == "Strawberry"
else:
assert "name" not in result.data["node"]
28 changes: 28 additions & 0 deletions tests/types/test_cast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import strawberry
from strawberry.types.cast import get_strawberry_type_cast


def test_cast():
@strawberry.type
class SomeType: ...

class OtherType: ...

obj = OtherType
assert get_strawberry_type_cast(obj) is None

cast_obj = strawberry.cast(SomeType, obj)
assert cast_obj is obj
assert get_strawberry_type_cast(cast_obj) is SomeType


def test_cast_none_obj():
@strawberry.type
class SomeType: ...

obj = None
assert get_strawberry_type_cast(obj) is None

cast_obj = strawberry.cast(SomeType, obj)
assert cast_obj is None
assert get_strawberry_type_cast(obj) is None

0 comments on commit 526eb82

Please sign in to comment.