Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add recursion guard (Sourcery refactored) #469

Closed
wants to merge 14 commits into from
11 changes: 11 additions & 0 deletions docs/examples/library_factories/sqlalchemy_factory/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from collections.abc import Iterable

import pytest

from docs.examples.library_factories.sqlalchemy_factory.test_example_4 import BaseFactory


@pytest.fixture(scope="module")
def _remove_default_factories() -> Iterable[None]:
yield
BaseFactory._base_factories.remove(BaseFactory) # noqa: SLF001
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from typing import List

from sqlalchemy import ForeignKey, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship

from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory, T


class Base(DeclarativeBase):
...


class Author(Base):
__tablename__ = "authors"

id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str]

books: Mapped[List["Book"]] = relationship(
"Book",
uselist=True,
back_populates="author",
)


class Book(Base):
__tablename__ = "books"

id: Mapped[int] = mapped_column(primary_key=True)
author_id: Mapped[int] = mapped_column(ForeignKey(Author.id), nullable=False)
author: Mapped[Author] = relationship(
"Author",
uselist=False,
back_populates="books",
)


class BaseFactory(SQLAlchemyFactory[T]):
__is_base_factory__ = True
__set_relationships__ = True
__randomize_collection_length__ = True
__min_collection_length__ = 3


def test_custom_sqla_factory() -> None:
engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
session = Session(engine)

BaseFactory.__session__ = session # Or using a callable that returns a session

author = BaseFactory.create_factory(Author).create_sync()
assert author.id is not None
assert author.id == author.books[0].author_id

book = BaseFactory.create_factory(Book).create_sync()
assert book.id is not None
assert book.author.books == [book]
10 changes: 10 additions & 0 deletions docs/usage/library_factories/sqlalchemy_factory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,16 @@ By default, this will add generated models to the session and then commit. This
Similarly for ``__async_session__`` and ``create_async``.


Adding global overrides
------------------------------

By combining the above and using other settings, a global base factory can be set up for other factories.

.. literalinclude:: /examples/library_factories/sqlalchemy_factory/test_example_4.py
:caption: Using persistence
:language: python


API reference
------------------------------
Full API docs are available :class:`here <polyfactory.factories.sqlalchemy_factory.SQLAlchemyFactory>`.
119 changes: 86 additions & 33 deletions polyfactory/factories/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import copy
from abc import ABC, abstractmethod
from collections import Counter, abc, deque
from contextlib import suppress
Expand Down Expand Up @@ -41,6 +42,7 @@
Mapping,
Sequence,
Type,
TypedDict,
TypeVar,
cast,
)
Expand All @@ -65,14 +67,7 @@
unwrap_optional,
)
from polyfactory.utils.model_coverage import CoverageContainer, CoverageContainerCallable, resolve_kwargs_coverage
from polyfactory.utils.predicates import (
get_type_origin,
is_any,
is_literal,
is_optional,
is_safe_subclass,
is_union,
)
from polyfactory.utils.predicates import get_type_origin, is_any, is_literal, is_optional, is_safe_subclass, is_union
from polyfactory.value_generators.complex_types import handle_collection_type, handle_collection_type_coverage
from polyfactory.value_generators.constrained_collections import (
handle_constrained_collection,
Expand All @@ -88,11 +83,7 @@
from polyfactory.value_generators.constrained_strings import handle_constrained_string_or_bytes
from polyfactory.value_generators.constrained_url import handle_constrained_url
from polyfactory.value_generators.constrained_uuid import handle_constrained_uuid
from polyfactory.value_generators.primitives import (
create_random_boolean,
create_random_bytes,
create_random_string,
)
from polyfactory.value_generators.primitives import create_random_boolean, create_random_bytes, create_random_string

if TYPE_CHECKING:
from typing_extensions import TypeGuard
Expand All @@ -105,6 +96,17 @@
F = TypeVar("F", bound="BaseFactory[Any]")


class BuildContext(TypedDict):
seen_models: set[type]


def _get_build_context(build_context: BuildContext | None) -> BuildContext:
if build_context is None:
return {"seen_models": set()}

return copy.deepcopy(build_context)


class BaseFactory(ABC, Generic[T]):
"""Base Factory class - this class holds the main logic of the library"""

Expand Down Expand Up @@ -243,10 +245,7 @@ class Foo(ModelFactory[MyModel]): # <<< MyModel
generic_args: Sequence[type[T]] = [
arg for factory_base in factory_bases for arg in get_args(factory_base) if not isinstance(arg, TypeVar)
]
if len(generic_args) != 1:
return None

return generic_args[0]
return None if len(generic_args) != 1 else generic_args[0]
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function BaseFactory._infer_model_type refactored with the following changes:


@classmethod
def _get_sync_persistence(cls) -> SyncPersistenceProtocol[T]:
Expand Down Expand Up @@ -277,7 +276,12 @@ def _get_async_persistence(cls) -> AsyncPersistenceProtocol[T]:
)

@classmethod
def _handle_factory_field(cls, field_value: Any, field_build_parameters: Any | None = None) -> Any:
def _handle_factory_field(
cls,
field_value: Any,
build_context: BuildContext,
field_build_parameters: Any | None = None,
) -> Any:
"""Handle a value defined on the factory class itself.

:param field_value: A value defined as an attribute on the factory class.
Expand All @@ -287,12 +291,14 @@ def _handle_factory_field(cls, field_value: Any, field_build_parameters: Any | N
"""
if is_safe_subclass(field_value, BaseFactory):
if isinstance(field_build_parameters, Mapping):
return field_value.build(**field_build_parameters)
return field_value.build(_build_context=build_context, **field_build_parameters)

if isinstance(field_build_parameters, Sequence):
return [field_value.build(**parameter) for parameter in field_build_parameters]
return [
field_value.build(_build_context=build_context, **parameter) for parameter in field_build_parameters
]

return field_value.build()
return field_value.build(_build_context=build_context)

if isinstance(field_value, Use):
return field_value.to_value()
Expand All @@ -303,7 +309,12 @@ def _handle_factory_field(cls, field_value: Any, field_build_parameters: Any | N
return field_value() if callable(field_value) else field_value

@classmethod
def _handle_factory_field_coverage(cls, field_value: Any, field_build_parameters: Any | None = None) -> Any:
def _handle_factory_field_coverage(
cls,
field_value: Any,
field_build_parameters: Any | None = None,
build_context: BuildContext | None = None,
) -> Any:
"""Handle a value defined on the factory class itself.

:param field_value: A value defined as an attribute on the factory class.
Expand All @@ -313,10 +324,13 @@ def _handle_factory_field_coverage(cls, field_value: Any, field_build_parameters
"""
if is_safe_subclass(field_value, BaseFactory):
if isinstance(field_build_parameters, Mapping):
return CoverageContainer(field_value.coverage(**field_build_parameters))
return CoverageContainer(field_value.coverage(_build_context=build_context, **field_build_parameters))

if isinstance(field_build_parameters, Sequence):
return [CoverageContainer(field_value.coverage(**parameter)) for parameter in field_build_parameters]
return [
CoverageContainer(field_value.coverage(_build_context=build_context, **parameter))
for parameter in field_build_parameters
]

return CoverageContainer(field_value.coverage())

Expand Down Expand Up @@ -621,15 +635,18 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912
cls,
field_meta: FieldMeta,
field_build_parameters: Any | None = None,
build_context: BuildContext | None = None,
) -> Any:
"""Return a field value on the subclass if existing, otherwise returns a mock value.

:param field_meta: FieldMeta instance.
:param field_build_parameters: Any build parameters passed to the factory as kwarg values.
:param build_context: BuildContext data for current build.

:returns: An arbitrary value.

"""
build_context = _get_build_context(build_context)
if cls.is_ignored_type(field_meta.annotation):
return None

Expand All @@ -648,20 +665,32 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912
return cls.get_constrained_field_value(annotation=unwrapped_annotation, field_meta=field_meta)

if BaseFactory.is_factory_type(annotation=unwrapped_annotation):
if not field_build_parameters and unwrapped_annotation in build_context["seen_models"]:
return None if is_optional(field_meta.annotation) else Null

return cls._get_or_create_factory(model=unwrapped_annotation).build(
_build_context=build_context,
**(field_build_parameters if isinstance(field_build_parameters, Mapping) else {}),
)

if BaseFactory.is_batch_factory_type(annotation=unwrapped_annotation):
factory = cls._get_or_create_factory(model=field_meta.type_args[0])
if isinstance(field_build_parameters, Sequence):
return [factory.build(**field_parameters) for field_parameters in field_build_parameters]
return [
factory.build(_build_context=build_context, **field_parameters)
for field_parameters in field_build_parameters
]

if field_meta.type_args[0] in build_context["seen_models"]:
return []

if not cls.__randomize_collection_length__:
return [factory.build()]
return [factory.build(_build_context=build_context)]

batch_size = cls.__random__.randint(cls.__min_collection_length__, cls.__max_collection_length__)
return factory.batch(size=batch_size)
return factory.batch(size=batch_size, _build_context=build_context)

if (origin := get_type_origin(unwrapped_annotation)) and issubclass(origin, Collection):
if (origin := get_type_origin(unwrapped_annotation)) and is_safe_subclass(origin, Collection):
if cls.__randomize_collection_length__:
collection_type = get_collection_type(unwrapped_annotation)
if collection_type != dict:
Expand All @@ -682,8 +711,9 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912

return handle_collection_type(field_meta, origin, cls)

if is_union(field_meta.annotation) and field_meta.children:
return cls.get_field_value(cls.__random__.choice(field_meta.children))
if is_union(unwrapped_annotation) and field_meta.children:
children = [child for child in field_meta.children if child.annotation not in build_context["seen_models"]]
return cls.get_field_value(cls.__random__.choice(children))

if is_any(unwrapped_annotation) or isinstance(unwrapped_annotation, TypeVar):
return create_random_string(cls.__random__, min_length=1, max_length=10)
Expand All @@ -707,11 +737,13 @@ def get_field_value_coverage( # noqa: C901
cls,
field_meta: FieldMeta,
field_build_parameters: Any | None = None,
build_context: BuildContext | None = None,
) -> Iterable[Any]:
"""Return a field value on the subclass if existing, otherwise returns a mock value.

:param field_meta: FieldMeta instance.
:param field_build_parameters: Any build parameters passed to the factory as kwarg values.
:param build_context: BuildContext data for current build.

:returns: An iterable of values.

Expand Down Expand Up @@ -739,6 +771,7 @@ def get_field_value_coverage( # noqa: C901
elif BaseFactory.is_factory_type(annotation=unwrapped_annotation):
yield CoverageContainer(
cls._get_or_create_factory(model=unwrapped_annotation).coverage(
_build_context=build_context,
**(field_build_parameters if isinstance(field_build_parameters, Mapping) else {}),
),
)
Expand Down Expand Up @@ -861,6 +894,9 @@ def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]:
:returns: A dictionary of build results.

"""
_build_context = _get_build_context(kwargs.pop("_build_context", None))
_build_context["seen_models"].add(cls.__model__)

result: dict[str, Any] = {**kwargs}
generate_post: dict[str, PostGenerated] = {}

Expand All @@ -883,10 +919,19 @@ def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]:
result[field_meta.name] = cls._handle_factory_field(
field_value=field_value,
field_build_parameters=field_build_parameters,
build_context=_build_context,
)
continue

result[field_meta.name] = cls.get_field_value(field_meta, field_build_parameters=field_build_parameters)
field_result = cls.get_field_value(
field_meta,
field_build_parameters=field_build_parameters,
build_context=_build_context,
)
if field_result is Null:
continue

result[field_meta.name] = field_result

for field_name, post_generator in generate_post.items():
result[field_name] = post_generator.to_value(field_name, result)
Expand All @@ -898,10 +943,14 @@ def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]:
"""Process the given kwargs and generate values for the factory's model.

:param kwargs: Any build kwargs.
:param build_context: BuildContext data for current build.

:returns: A dictionary of build results.

"""
_build_context = _get_build_context(kwargs.pop("_build_context", None))
_build_context["seen_models"].add(cls.__model__)

result: dict[str, Any] = {**kwargs}
generate_post: dict[str, PostGenerated] = {}

Expand All @@ -925,11 +974,16 @@ def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]:
result[field_meta.name] = cls._handle_factory_field_coverage(
field_value=field_value,
field_build_parameters=field_build_parameters,
build_context=_build_context,
)
continue

result[field_meta.name] = CoverageContainer(
cls.get_field_value_coverage(field_meta, field_build_parameters=field_build_parameters),
cls.get_field_value_coverage(
field_meta,
field_build_parameters=field_build_parameters,
build_context=_build_context,
),
)

for resolved in resolve_kwargs_coverage(result):
Expand All @@ -946,7 +1000,6 @@ def build(cls, **kwargs: Any) -> T:
:returns: An instance of type T.

"""

return cast("T", cls.__model__(**cls.process_kwargs(**kwargs)))

@classmethod
Expand Down
Loading
Loading