Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to set a custom info class for a schema #3592

Merged
merged 14 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
Release type: minor

You can now configure your schemas to provide a custom subclass of
`strawberry.types.Info` to your types and queries.

```py
import strawberry
from strawberry.schema.config import StrawberryConfig

from .models import ProductModel


class CustomInfo(strawberry.Info):
@property
def selected_group_id(self) -> int | None:
"""Get the ID of the group you're logged in as."""
return self.context["request"].headers.get("Group-ID")


@strawberry.type
class Group:
id: strawberry.ID
name: str


@strawberry.type
class User:
id: strawberry.ID
name: str
group: Group


@strawberry.type
class Query:
@strawberry.field
def user(self, id: strawberry.ID, info: CustomInfo) -> Product:
kwargs = {"id": id, "name": ...}

if info.selected_group_id is not None:
# Get information about the group you're a part of, if
# available.
kwargs["group"] = ...

return User(**kwargs)


schema = strawberry.Schema(
Query,
config=StrawberryConfig(info_class=CustomInfo),
)
```
20 changes: 20 additions & 0 deletions docs/types/schema-configurations.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,23 @@ schema = strawberry.Schema(
query=Query, config=StrawberryConfig(disable_field_suggestions=True)
)
```

### info_class

By default Strawberry will create an object of type `strawberry.Info` when the
user defines `info: Info` as a parameter to a type or query. You can change this
behaviour by setting `info_class` to a subclass of `strawberry.Info`.

This can be useful when you want to create a simpler interface for info- or
context-based properties, or if you wanted to attach additional properties to
the `Info` class.

```python
class CustomInfo(Info):
@property
def response_headers(self) -> Headers:
return self.context["response"].headers


schema = strawberry.Schema(query=Query, info_class=CustomInfo)
```
5 changes: 3 additions & 2 deletions strawberry/extensions/directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, Tuple

from strawberry.extensions import SchemaExtension
from strawberry.types import Info
from strawberry.types.nodes import convert_arguments
from strawberry.utils.await_maybe import await_maybe

Expand Down Expand Up @@ -81,7 +80,9 @@ def process_directive(
field_name=info.field_name,
type_name=info.parent_type.name,
)
arguments[info_parameter.name] = Info(_raw_info=info, _field=field)
arguments[info_parameter.name] = schema.config.info_class(
_raw_info=info, _field=field
)
if value_parameter:
arguments[value_parameter.name] = value
return strawberry_directive, arguments
Expand Down
6 changes: 6 additions & 0 deletions strawberry/schema/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from dataclasses import InitVar, dataclass, field
from typing import Any, Callable

from strawberry.types.info import Info

from .name_converter import NameConverter


Expand All @@ -13,6 +15,7 @@ class StrawberryConfig:
default_resolver: Callable[[Any, str], object] = getattr
relay_max_results: int = 100
disable_field_suggestions: bool = False
info_class: type[Info] = Info

def __post_init__(
self,
Expand All @@ -21,5 +24,8 @@ def __post_init__(
if auto_camel_case is not None:
self.name_converter.auto_camel_case = auto_camel_case

if not issubclass(self.info_class, Info):
raise TypeError("`info_class` must be a subclass of strawberry.Info")


__all__ = ["StrawberryConfig"]
4 changes: 2 additions & 2 deletions strawberry/schema/schema_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@
)
from strawberry.types.enum import EnumDefinition
from strawberry.types.field import UNRESOLVED
from strawberry.types.info import Info
from strawberry.types.lazy_type import LazyType
from strawberry.types.private import is_private
from strawberry.types.scalar import ScalarWrapper
Expand All @@ -90,6 +89,7 @@
from strawberry.schema_directive import StrawberrySchemaDirective
from strawberry.types.enum import EnumValue
from strawberry.types.field import StrawberryField
from strawberry.types.info import Info
from strawberry.types.scalar import ScalarDefinition


Expand Down Expand Up @@ -664,7 +664,7 @@ def _get_basic_result(_source: Any, *args: str, **kwargs: Any) -> Any:
return _get_basic_result

def _strawberry_info_from_graphql(info: GraphQLResolveInfo) -> Info:
return Info(
return self.config.info_class(
_raw_info=info,
_field=field,
)
Expand Down
34 changes: 34 additions & 0 deletions tests/schema/extensions/test_field_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
FieldExtension,
SyncExtensionResolver,
)
from strawberry.schema.config import StrawberryConfig


class UpperCaseExtension(FieldExtension):
Expand Down Expand Up @@ -380,3 +381,36 @@ def string(
},
"another_input": {},
}


def test_extension_has_custom_info_class():
class CustomInfo(strawberry.Info):
test: str = "foo"

class CustomExtension(FieldExtension):
def resolve(
self,
next_: Callable[..., Any],
source: Any,
info: CustomInfo,
**kwargs: Any,
):
assert isinstance(info, CustomInfo)
# Explicitly check it's not Info.
assert strawberry.Info in type(info).__bases__
assert info.test == "foo"
return next_(source, info, **kwargs)

@strawberry.type
class Query:
@strawberry.field(extensions=[CustomExtension()])
def string(self) -> str:
return "This is a test!!"

schema = strawberry.Schema(
query=Query, config=StrawberryConfig(info_class=CustomInfo)
)
query = "query { string }"
result = schema.execute_sync(query)
assert result.data, result.errors
assert result.data["string"] == "This is a test!!"
39 changes: 39 additions & 0 deletions tests/schema/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import pytest

from strawberry.schema.config import StrawberryConfig
from strawberry.types.info import Info


def test_config_post_init_auto_camel_case():
config = StrawberryConfig(auto_camel_case=True)

assert config.name_converter.auto_camel_case is True


def test_config_post_init_no_auto_camel_case():
config = StrawberryConfig(auto_camel_case=False)

assert config.name_converter.auto_camel_case is False


def test_config_post_init_info_class():
class CustomInfo(Info):
test: str = "foo"

config = StrawberryConfig(info_class=CustomInfo)

assert config.info_class is CustomInfo
assert config.info_class.test == "foo"


def test_config_post_init_info_class_is_default():
config = StrawberryConfig()

assert config.info_class is Info


def test_config_post_init_info_class_is_not_subclass():
with pytest.raises(TypeError) as exc_info:
StrawberryConfig(info_class=object)

assert str(exc_info.value) == "`info_class` must be a subclass of strawberry.Info"
34 changes: 34 additions & 0 deletions tests/schema/test_directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

import strawberry
from strawberry import Info
from strawberry.directive import DirectiveLocation, DirectiveValue
from strawberry.extensions import SchemaExtension
from strawberry.schema.config import StrawberryConfig
Expand Down Expand Up @@ -654,3 +655,36 @@ def uppercase(value: str, input: DirectiveInput):
'''

assert schema.as_str() == textwrap.dedent(expected_schema).strip()


@pytest.mark.asyncio
async def test_directive_with_custom_info_class() -> NoReturn:
@strawberry.type
class Query:
@strawberry.field
def greeting(self) -> str:
return "Hi"

class CustomInfo(Info):
test: str = "foo"

@strawberry.directive(locations=[DirectiveLocation.FIELD])
def append_names(value: DirectiveValue[str], names: List[str], info: CustomInfo):
assert isinstance(names, list)
assert isinstance(info, CustomInfo)
assert Info in type(info).__bases__ # Explicitly check it's not Info.
assert info.test == "foo"
return f"{value} {', '.join(names)}"

schema = strawberry.Schema(
query=Query,
directives=[append_names],
config=StrawberryConfig(info_class=CustomInfo),
)

result = await schema.execute(
'query { greeting @appendNames(names: ["foo", "bar"])}'
)

assert result.errors is None
assert result.data["greeting"] == "Hi foo, bar"
Loading