Skip to content

Commit

Permalink
Add ability to set a custom info class for a schema (#3592)
Browse files Browse the repository at this point in the history
* Add ability to use custom info objects schema-wide

* Add RELEASE.md

* Improve RELEASE.md example

* Use 'Info' class directly instead of default factory for dataclass field

* Add no cover pragma + change to TypeError over ValueError

* Update the RELEASE example

* Update documentation to make note of the new 'info_class' config option

* Use 'strawberry.Info' instead of 'strawberry.types.Info' in info_class docs

Co-authored-by: Jonathan Ehwald <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add __post_init__ tests for StrawberryConfig

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Allow for the use of custom info classes within directives

* Add testing to ensure custom info classes are included in directives and extensions

---------

Co-authored-by: Jonathan Ehwald <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Sep 16, 2024
1 parent 13bd97b commit 7287047
Show file tree
Hide file tree
Showing 8 changed files with 189 additions and 4 deletions.
51 changes: 51 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
Release type: minor

You can now configure your schemas to provide a custom subclass of
`strawberry.types.Info` to your types and queries.

```py
import strawberry
from strawberry.schema.config import StrawberryConfig

from .models import ProductModel


class CustomInfo(strawberry.Info):
@property
def selected_group_id(self) -> int | None:
"""Get the ID of the group you're logged in as."""
return self.context["request"].headers.get("Group-ID")


@strawberry.type
class Group:
id: strawberry.ID
name: str


@strawberry.type
class User:
id: strawberry.ID
name: str
group: Group


@strawberry.type
class Query:
@strawberry.field
def user(self, id: strawberry.ID, info: CustomInfo) -> Product:
kwargs = {"id": id, "name": ...}

if info.selected_group_id is not None:
# Get information about the group you're a part of, if
# available.
kwargs["group"] = ...

return User(**kwargs)


schema = strawberry.Schema(
Query,
config=StrawberryConfig(info_class=CustomInfo),
)
```
20 changes: 20 additions & 0 deletions docs/types/schema-configurations.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,23 @@ schema = strawberry.Schema(
query=Query, config=StrawberryConfig(disable_field_suggestions=True)
)
```

### info_class

By default Strawberry will create an object of type `strawberry.Info` when the
user defines `info: Info` as a parameter to a type or query. You can change this
behaviour by setting `info_class` to a subclass of `strawberry.Info`.

This can be useful when you want to create a simpler interface for info- or
context-based properties, or if you wanted to attach additional properties to
the `Info` class.

```python
class CustomInfo(Info):
@property
def response_headers(self) -> Headers:
return self.context["response"].headers


schema = strawberry.Schema(query=Query, info_class=CustomInfo)
```
5 changes: 3 additions & 2 deletions strawberry/extensions/directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, Tuple

from strawberry.extensions import SchemaExtension
from strawberry.types import Info
from strawberry.types.nodes import convert_arguments
from strawberry.utils.await_maybe import await_maybe

Expand Down Expand Up @@ -81,7 +80,9 @@ def process_directive(
field_name=info.field_name,
type_name=info.parent_type.name,
)
arguments[info_parameter.name] = Info(_raw_info=info, _field=field)
arguments[info_parameter.name] = schema.config.info_class(
_raw_info=info, _field=field
)
if value_parameter:
arguments[value_parameter.name] = value
return strawberry_directive, arguments
Expand Down
6 changes: 6 additions & 0 deletions strawberry/schema/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from dataclasses import InitVar, dataclass, field
from typing import Any, Callable

from strawberry.types.info import Info

from .name_converter import NameConverter


Expand All @@ -13,6 +15,7 @@ class StrawberryConfig:
default_resolver: Callable[[Any, str], object] = getattr
relay_max_results: int = 100
disable_field_suggestions: bool = False
info_class: type[Info] = Info

def __post_init__(
self,
Expand All @@ -21,5 +24,8 @@ def __post_init__(
if auto_camel_case is not None:
self.name_converter.auto_camel_case = auto_camel_case

if not issubclass(self.info_class, Info):
raise TypeError("`info_class` must be a subclass of strawberry.Info")


__all__ = ["StrawberryConfig"]
4 changes: 2 additions & 2 deletions strawberry/schema/schema_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@
)
from strawberry.types.enum import EnumDefinition
from strawberry.types.field import UNRESOLVED
from strawberry.types.info import Info
from strawberry.types.lazy_type import LazyType
from strawberry.types.private import is_private
from strawberry.types.scalar import ScalarWrapper
Expand All @@ -90,6 +89,7 @@
from strawberry.schema_directive import StrawberrySchemaDirective
from strawberry.types.enum import EnumValue
from strawberry.types.field import StrawberryField
from strawberry.types.info import Info
from strawberry.types.scalar import ScalarDefinition


Expand Down Expand Up @@ -664,7 +664,7 @@ def _get_basic_result(_source: Any, *args: str, **kwargs: Any) -> Any:
return _get_basic_result

def _strawberry_info_from_graphql(info: GraphQLResolveInfo) -> Info:
return Info(
return self.config.info_class(
_raw_info=info,
_field=field,
)
Expand Down
34 changes: 34 additions & 0 deletions tests/schema/extensions/test_field_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
FieldExtension,
SyncExtensionResolver,
)
from strawberry.schema.config import StrawberryConfig


class UpperCaseExtension(FieldExtension):
Expand Down Expand Up @@ -380,3 +381,36 @@ def string(
},
"another_input": {},
}


def test_extension_has_custom_info_class():
class CustomInfo(strawberry.Info):
test: str = "foo"

class CustomExtension(FieldExtension):
def resolve(
self,
next_: Callable[..., Any],
source: Any,
info: CustomInfo,
**kwargs: Any,
):
assert isinstance(info, CustomInfo)
# Explicitly check it's not Info.
assert strawberry.Info in type(info).__bases__
assert info.test == "foo"
return next_(source, info, **kwargs)

@strawberry.type
class Query:
@strawberry.field(extensions=[CustomExtension()])
def string(self) -> str:
return "This is a test!!"

schema = strawberry.Schema(
query=Query, config=StrawberryConfig(info_class=CustomInfo)
)
query = "query { string }"
result = schema.execute_sync(query)
assert result.data, result.errors
assert result.data["string"] == "This is a test!!"
39 changes: 39 additions & 0 deletions tests/schema/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import pytest

from strawberry.schema.config import StrawberryConfig
from strawberry.types.info import Info


def test_config_post_init_auto_camel_case():
config = StrawberryConfig(auto_camel_case=True)

assert config.name_converter.auto_camel_case is True


def test_config_post_init_no_auto_camel_case():
config = StrawberryConfig(auto_camel_case=False)

assert config.name_converter.auto_camel_case is False


def test_config_post_init_info_class():
class CustomInfo(Info):
test: str = "foo"

config = StrawberryConfig(info_class=CustomInfo)

assert config.info_class is CustomInfo
assert config.info_class.test == "foo"


def test_config_post_init_info_class_is_default():
config = StrawberryConfig()

assert config.info_class is Info


def test_config_post_init_info_class_is_not_subclass():
with pytest.raises(TypeError) as exc_info:
StrawberryConfig(info_class=object)

assert str(exc_info.value) == "`info_class` must be a subclass of strawberry.Info"
34 changes: 34 additions & 0 deletions tests/schema/test_directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

import strawberry
from strawberry import Info
from strawberry.directive import DirectiveLocation, DirectiveValue
from strawberry.extensions import SchemaExtension
from strawberry.schema.config import StrawberryConfig
Expand Down Expand Up @@ -654,3 +655,36 @@ def uppercase(value: str, input: DirectiveInput):
'''

assert schema.as_str() == textwrap.dedent(expected_schema).strip()


@pytest.mark.asyncio
async def test_directive_with_custom_info_class() -> NoReturn:
@strawberry.type
class Query:
@strawberry.field
def greeting(self) -> str:
return "Hi"

class CustomInfo(Info):
test: str = "foo"

@strawberry.directive(locations=[DirectiveLocation.FIELD])
def append_names(value: DirectiveValue[str], names: List[str], info: CustomInfo):
assert isinstance(names, list)
assert isinstance(info, CustomInfo)
assert Info in type(info).__bases__ # Explicitly check it's not Info.
assert info.test == "foo"
return f"{value} {', '.join(names)}"

schema = strawberry.Schema(
query=Query,
directives=[append_names],
config=StrawberryConfig(info_class=CustomInfo),
)

result = await schema.execute(
'query { greeting @appendNames(names: ["foo", "bar"])}'
)

assert result.errors is None
assert result.data["greeting"] == "Hi foo, bar"

0 comments on commit 7287047

Please sign in to comment.