diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..2a3ac6f238 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,51 @@ +Release type: minor + +You can now configure your schemas to provide a custom subclass of +`strawberry.types.Info` to your types and queries. + +```py +import strawberry +from strawberry.schema.config import StrawberryConfig + +from .models import ProductModel + + +class CustomInfo(strawberry.Info): + @property + def selected_group_id(self) -> int | None: + """Get the ID of the group you're logged in as.""" + return self.context["request"].headers.get("Group-ID") + + +@strawberry.type +class Group: + id: strawberry.ID + name: str + + +@strawberry.type +class User: + id: strawberry.ID + name: str + group: Group + + +@strawberry.type +class Query: + @strawberry.field + def user(self, id: strawberry.ID, info: CustomInfo) -> Product: + kwargs = {"id": id, "name": ...} + + if info.selected_group_id is not None: + # Get information about the group you're a part of, if + # available. + kwargs["group"] = ... + + return User(**kwargs) + + +schema = strawberry.Schema( + Query, + config=StrawberryConfig(info_class=CustomInfo), +) +``` diff --git a/docs/types/schema-configurations.md b/docs/types/schema-configurations.md index 099153238e..5ad12c9c6c 100644 --- a/docs/types/schema-configurations.md +++ b/docs/types/schema-configurations.md @@ -105,3 +105,23 @@ schema = strawberry.Schema( query=Query, config=StrawberryConfig(disable_field_suggestions=True) ) ``` + +### info_class + +By default Strawberry will create an object of type `strawberry.Info` when the +user defines `info: Info` as a parameter to a type or query. You can change this +behaviour by setting `info_class` to a subclass of `strawberry.Info`. + +This can be useful when you want to create a simpler interface for info- or +context-based properties, or if you wanted to attach additional properties to +the `Info` class. + +```python +class CustomInfo(Info): + @property + def response_headers(self) -> Headers: + return self.context["response"].headers + + +schema = strawberry.Schema(query=Query, info_class=CustomInfo) +``` diff --git a/strawberry/extensions/directives.py b/strawberry/extensions/directives.py index b72923adda..82f9efe146 100644 --- a/strawberry/extensions/directives.py +++ b/strawberry/extensions/directives.py @@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Tuple from strawberry.extensions import SchemaExtension -from strawberry.types import Info from strawberry.types.nodes import convert_arguments from strawberry.utils.await_maybe import await_maybe @@ -81,7 +80,9 @@ def process_directive( field_name=info.field_name, type_name=info.parent_type.name, ) - arguments[info_parameter.name] = Info(_raw_info=info, _field=field) + arguments[info_parameter.name] = schema.config.info_class( + _raw_info=info, _field=field + ) if value_parameter: arguments[value_parameter.name] = value return strawberry_directive, arguments diff --git a/strawberry/schema/config.py b/strawberry/schema/config.py index 2c5a37eb53..230ae7dc10 100644 --- a/strawberry/schema/config.py +++ b/strawberry/schema/config.py @@ -3,6 +3,8 @@ from dataclasses import InitVar, dataclass, field from typing import Any, Callable +from strawberry.types.info import Info + from .name_converter import NameConverter @@ -13,6 +15,7 @@ class StrawberryConfig: default_resolver: Callable[[Any, str], object] = getattr relay_max_results: int = 100 disable_field_suggestions: bool = False + info_class: type[Info] = Info def __post_init__( self, @@ -21,5 +24,8 @@ def __post_init__( if auto_camel_case is not None: self.name_converter.auto_camel_case = auto_camel_case + if not issubclass(self.info_class, Info): + raise TypeError("`info_class` must be a subclass of strawberry.Info") + __all__ = ["StrawberryConfig"] diff --git a/strawberry/schema/schema_converter.py b/strawberry/schema/schema_converter.py index 9d19ad0345..1083b46f9b 100644 --- a/strawberry/schema/schema_converter.py +++ b/strawberry/schema/schema_converter.py @@ -64,7 +64,6 @@ ) from strawberry.types.enum import EnumDefinition from strawberry.types.field import UNRESOLVED -from strawberry.types.info import Info from strawberry.types.lazy_type import LazyType from strawberry.types.private import is_private from strawberry.types.scalar import ScalarWrapper @@ -90,6 +89,7 @@ from strawberry.schema_directive import StrawberrySchemaDirective from strawberry.types.enum import EnumValue from strawberry.types.field import StrawberryField + from strawberry.types.info import Info from strawberry.types.scalar import ScalarDefinition @@ -664,7 +664,7 @@ def _get_basic_result(_source: Any, *args: str, **kwargs: Any) -> Any: return _get_basic_result def _strawberry_info_from_graphql(info: GraphQLResolveInfo) -> Info: - return Info( + return self.config.info_class( _raw_info=info, _field=field, ) diff --git a/tests/schema/extensions/test_field_extensions.py b/tests/schema/extensions/test_field_extensions.py index eea9e0e1ab..a59eaa9938 100644 --- a/tests/schema/extensions/test_field_extensions.py +++ b/tests/schema/extensions/test_field_extensions.py @@ -10,6 +10,7 @@ FieldExtension, SyncExtensionResolver, ) +from strawberry.schema.config import StrawberryConfig class UpperCaseExtension(FieldExtension): @@ -380,3 +381,36 @@ def string( }, "another_input": {}, } + + +def test_extension_has_custom_info_class(): + class CustomInfo(strawberry.Info): + test: str = "foo" + + class CustomExtension(FieldExtension): + def resolve( + self, + next_: Callable[..., Any], + source: Any, + info: CustomInfo, + **kwargs: Any, + ): + assert isinstance(info, CustomInfo) + # Explicitly check it's not Info. + assert strawberry.Info in type(info).__bases__ + assert info.test == "foo" + return next_(source, info, **kwargs) + + @strawberry.type + class Query: + @strawberry.field(extensions=[CustomExtension()]) + def string(self) -> str: + return "This is a test!!" + + schema = strawberry.Schema( + query=Query, config=StrawberryConfig(info_class=CustomInfo) + ) + query = "query { string }" + result = schema.execute_sync(query) + assert result.data, result.errors + assert result.data["string"] == "This is a test!!" diff --git a/tests/schema/test_config.py b/tests/schema/test_config.py new file mode 100644 index 0000000000..665b772391 --- /dev/null +++ b/tests/schema/test_config.py @@ -0,0 +1,39 @@ +import pytest + +from strawberry.schema.config import StrawberryConfig +from strawberry.types.info import Info + + +def test_config_post_init_auto_camel_case(): + config = StrawberryConfig(auto_camel_case=True) + + assert config.name_converter.auto_camel_case is True + + +def test_config_post_init_no_auto_camel_case(): + config = StrawberryConfig(auto_camel_case=False) + + assert config.name_converter.auto_camel_case is False + + +def test_config_post_init_info_class(): + class CustomInfo(Info): + test: str = "foo" + + config = StrawberryConfig(info_class=CustomInfo) + + assert config.info_class is CustomInfo + assert config.info_class.test == "foo" + + +def test_config_post_init_info_class_is_default(): + config = StrawberryConfig() + + assert config.info_class is Info + + +def test_config_post_init_info_class_is_not_subclass(): + with pytest.raises(TypeError) as exc_info: + StrawberryConfig(info_class=object) + + assert str(exc_info.value) == "`info_class` must be a subclass of strawberry.Info" diff --git a/tests/schema/test_directives.py b/tests/schema/test_directives.py index 9c82d483e0..3b31e97258 100644 --- a/tests/schema/test_directives.py +++ b/tests/schema/test_directives.py @@ -5,6 +5,7 @@ import pytest import strawberry +from strawberry import Info from strawberry.directive import DirectiveLocation, DirectiveValue from strawberry.extensions import SchemaExtension from strawberry.schema.config import StrawberryConfig @@ -654,3 +655,36 @@ def uppercase(value: str, input: DirectiveInput): ''' assert schema.as_str() == textwrap.dedent(expected_schema).strip() + + +@pytest.mark.asyncio +async def test_directive_with_custom_info_class() -> NoReturn: + @strawberry.type + class Query: + @strawberry.field + def greeting(self) -> str: + return "Hi" + + class CustomInfo(Info): + test: str = "foo" + + @strawberry.directive(locations=[DirectiveLocation.FIELD]) + def append_names(value: DirectiveValue[str], names: List[str], info: CustomInfo): + assert isinstance(names, list) + assert isinstance(info, CustomInfo) + assert Info in type(info).__bases__ # Explicitly check it's not Info. + assert info.test == "foo" + return f"{value} {', '.join(names)}" + + schema = strawberry.Schema( + query=Query, + directives=[append_names], + config=StrawberryConfig(info_class=CustomInfo), + ) + + result = await schema.execute( + 'query { greeting @appendNames(names: ["foo", "bar"])}' + ) + + assert result.errors is None + assert result.data["greeting"] == "Hi foo, bar"