Skip to content

Commit

Permalink
feat: add recursion guard
Browse files Browse the repository at this point in the history
  • Loading branch information
adhtruong committed Jan 2, 2024
1 parent 8dc8e1a commit 5c0356d
Show file tree
Hide file tree
Showing 12 changed files with 315 additions and 84 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from typing import List

from sqlalchemy import ForeignKey, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship

from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory, T


class Base(DeclarativeBase):
...


class Author(Base):
__tablename__ = "authors"

id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str]

books: Mapped[List["Book"]] = relationship(
"Book",
uselist=True,
back_populates="author",
)


class Book(Base):
__tablename__ = "books"

id: Mapped[int] = mapped_column(primary_key=True)
author_id: Mapped[int] = mapped_column(ForeignKey(Author.id), nullable=False)
author: Mapped[Author] = relationship(
"Author",
uselist=False,
back_populates="books",
)


class BaseFactory(SQLAlchemyFactory[T]):
__is_base_factory__ = True
__set_relationships__ = True
__randomize_collection_length__ = True
__min_collection_length__ = 3


def test_custom_sqla_factory() -> None:
engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
session = Session(engine)

BaseFactory.__session__ = session # Or using a callable that returns a session

author = BaseFactory.create_factory(Author).create_sync()
assert author.id is not None
assert author.id == author.books[0].author_id

book = BaseFactory.create_factory(Book).create_sync()
assert book.id is not None
assert book.author.books == [book]

BaseFactory._base_factories.remove(BaseFactory) # noqa: SLF001
10 changes: 10 additions & 0 deletions docs/usage/library_factories/sqlalchemy_factory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,16 @@ By default, this will add generated models to the session and then commit. This
Similarly for ``__async_session__`` and ``create_async``.


Adding global overrides
------------------------------

By combining the above and using other settings, a global base factory can be set up for other factories.

.. literalinclude:: /examples/library_factories/sqlalchemy_factory/test_example_4.py
:caption: Using persistence
:language: python


API reference
------------------------------
Full API docs are available :class:`here <polyfactory.factories.sqlalchemy_factory.SQLAlchemyFactory>`.
147 changes: 101 additions & 46 deletions pdm.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion polyfactory/factories/attrs_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,4 @@ def resolve_types(cls, model: type[T], **kwargs: Any) -> None:
:param kwargs: Any parameters that need to be passed to `attrs.resolve_types`.
"""

attrs.resolve_types(model, **kwargs) # type: ignore[type-var]
attrs.resolve_types(model, **kwargs)
85 changes: 57 additions & 28 deletions polyfactory/factories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
Mapping,
Sequence,
Type,
TypedDict,
TypeVar,
cast,
)
Expand All @@ -63,14 +64,7 @@
unwrap_optional,
)
from polyfactory.utils.model_coverage import CoverageContainer, CoverageContainerCallable, resolve_kwargs_coverage
from polyfactory.utils.predicates import (
get_type_origin,
is_any,
is_literal,
is_optional,
is_safe_subclass,
is_union,
)
from polyfactory.utils.predicates import get_type_origin, is_any, is_literal, is_optional, is_safe_subclass, is_union
from polyfactory.value_generators.complex_types import handle_collection_type, handle_collection_type_coverage
from polyfactory.value_generators.constrained_collections import (
handle_constrained_collection,
Expand All @@ -86,11 +80,7 @@
from polyfactory.value_generators.constrained_strings import handle_constrained_string_or_bytes
from polyfactory.value_generators.constrained_url import handle_constrained_url
from polyfactory.value_generators.constrained_uuid import handle_constrained_uuid
from polyfactory.value_generators.primitives import (
create_random_boolean,
create_random_bytes,
create_random_string,
)
from polyfactory.value_generators.primitives import create_random_boolean, create_random_bytes, create_random_string

if TYPE_CHECKING:
from typing_extensions import TypeGuard
Expand All @@ -103,6 +93,15 @@
F = TypeVar("F", bound="BaseFactory[Any]")


class BuildContext(TypedDict):
seen_models: set[type]


def _get_build_context(build_context: BuildContext | None) -> BuildContext:
build_context = build_context or {"seen_models": set()}
return build_context.copy()


class BaseFactory(ABC, Generic[T]):
"""Base Factory class - this class holds the main logic of the library"""

Expand Down Expand Up @@ -271,7 +270,12 @@ def _get_async_persistence(cls) -> AsyncPersistenceProtocol[T]:
)

@classmethod
def _handle_factory_field(cls, field_value: Any, field_build_parameters: Any | None = None) -> Any:
def _handle_factory_field(
cls,
field_value: Any,
build_context: BuildContext,
field_build_parameters: Any | None = None,
) -> Any:
"""Handle a value defined on the factory class itself.
:param field_value: A value defined as an attribute on the factory class.
Expand All @@ -281,12 +285,12 @@ def _handle_factory_field(cls, field_value: Any, field_build_parameters: Any | N
"""
if is_safe_subclass(field_value, BaseFactory):
if isinstance(field_build_parameters, Mapping):
return field_value.build(**field_build_parameters)
return field_value.build(build_context, **field_build_parameters)

if isinstance(field_build_parameters, Sequence):
return [field_value.build(**parameter) for parameter in field_build_parameters]
return [field_value.build(build_context, **parameter) for parameter in field_build_parameters]

return field_value.build()
return field_value.build(build_context)

if isinstance(field_value, Use):
return field_value.to_value()
Expand Down Expand Up @@ -615,6 +619,7 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912
cls,
field_meta: FieldMeta,
field_build_parameters: Any | None = None,
build_context: BuildContext | None = None,
) -> Any:
"""Return a field value on the subclass if existing, otherwise returns a mock value.
Expand All @@ -624,6 +629,7 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912
:returns: An arbitrary value.
"""
build_context = _get_build_context(build_context)
if cls.is_ignored_type(field_meta.annotation):
return None

Expand All @@ -642,18 +648,27 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912
return cls.get_constrained_field_value(annotation=unwrapped_annotation, field_meta=field_meta)

if BaseFactory.is_factory_type(annotation=unwrapped_annotation):
if not field_build_parameters and unwrapped_annotation in build_context["seen_models"]:
return None

return cls._get_or_create_factory(model=unwrapped_annotation).build(
build_context,
**(field_build_parameters if isinstance(field_build_parameters, Mapping) else {}),
)

if BaseFactory.is_batch_factory_type(annotation=unwrapped_annotation):
factory = cls._get_or_create_factory(model=field_meta.type_args[0])
if isinstance(field_build_parameters, Sequence):
return [factory.build(**field_parameters) for field_parameters in field_build_parameters]
return [factory.build(build_context, **field_parameters) for field_parameters in field_build_parameters]

if field_meta.type_args[0] in build_context["seen_models"]:
return []

if not cls.__randomize_collection_length__:
return [factory.build()]
return [factory.build(build_context)]

batch_size = cls.__random__.randint(cls.__min_collection_length__, cls.__max_collection_length__)
return factory.batch(size=batch_size)
return factory.batch(size=batch_size, build_context=build_context)

if (origin := get_type_origin(unwrapped_annotation)) and issubclass(origin, Collection):
if cls.__randomize_collection_length__:
Expand Down Expand Up @@ -833,14 +848,17 @@ def _check_declared_fields_exist_in_model(cls) -> None:
raise ConfigurationException(error_message)

@classmethod
def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]:
def process_kwargs(cls, build_context: BuildContext | None = None, **kwargs: Any) -> dict[str, Any]:
"""Process the given kwargs and generate values for the factory's model.
:param kwargs: Any build kwargs.
:returns: A dictionary of build results.
"""
build_context = _get_build_context(build_context)
build_context["seen_models"] = build_context["seen_models"] | {cls.__model__}

result: dict[str, Any] = {**kwargs}
generate_post: dict[str, PostGenerated] = {}

Expand All @@ -863,25 +881,37 @@ def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]:
result[field_meta.name] = cls._handle_factory_field(
field_value=field_value,
field_build_parameters=field_build_parameters,
build_context=build_context,
)
continue

result[field_meta.name] = cls.get_field_value(field_meta, field_build_parameters=field_build_parameters)
result[field_meta.name] = cls.get_field_value(
field_meta,
field_build_parameters=field_build_parameters,
build_context=build_context,
)

for field_name, post_generator in generate_post.items():
result[field_name] = post_generator.to_value(field_name, result)

return result

@classmethod
def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]:
def process_kwargs_coverage(
cls,
build_context: BuildContext | None = None,
**kwargs: Any,
) -> abc.Iterable[dict[str, Any]]:
"""Process the given kwargs and generate values for the factory's model.
:param kwargs: Any build kwargs.
:returns: A dictionary of build results.
"""
build_context = _get_build_context(build_context)
build_context["seen_models"] = build_context["seen_models"] | {cls.__model__}

result: dict[str, Any] = {**kwargs}
generate_post: dict[str, PostGenerated] = {}

Expand Down Expand Up @@ -918,19 +948,18 @@ def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]:
yield resolved

@classmethod
def build(cls, **kwargs: Any) -> T:
def build(cls, build_context: BuildContext | None = None, **kwargs: Any) -> T:
"""Build an instance of the factory's __model__
:param kwargs: Any kwargs. If field names are set in kwargs, their values will be used.
:returns: An instance of type T.
"""

return cast("T", cls.__model__(**cls.process_kwargs(**kwargs)))
return cast("T", cls.__model__(**cls.process_kwargs(build_context, **kwargs)))

@classmethod
def batch(cls, size: int, **kwargs: Any) -> list[T]:
def batch(cls, size: int, build_context: BuildContext | None = None, **kwargs: Any) -> list[T]:
"""Build a batch of size n of the factory's Meta.model.
:param size: Size of the batch.
Expand All @@ -939,7 +968,7 @@ def batch(cls, size: int, **kwargs: Any) -> list[T]:
:returns: A list of instances of type T.
"""
return [cls.build(**kwargs) for _ in range(size)]
return [cls.build(build_context, **kwargs) for _ in range(size)]

@classmethod
def coverage(cls, **kwargs: Any) -> abc.Iterator[T]:
Expand Down
14 changes: 12 additions & 2 deletions polyfactory/factories/beanie_odm_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
if TYPE_CHECKING:
from typing_extensions import TypeGuard

from polyfactory.factories.base import BuildContext
from polyfactory.field_meta import FieldMeta

try:
Expand Down Expand Up @@ -55,7 +56,12 @@ def is_supported_type(cls, value: Any) -> "TypeGuard[type[T]]":
return is_safe_subclass(value, Document)

@classmethod
def get_field_value(cls, field_meta: "FieldMeta", field_build_parameters: Any | None = None) -> Any:
def get_field_value(
cls,
field_meta: "FieldMeta",
field_build_parameters: Any | None = None,
build_context: BuildContext | None = None,
) -> Any:
"""Return a field value on the subclass if existing, otherwise returns a mock value.
:param field_meta: FieldMeta instance.
Expand All @@ -74,4 +80,8 @@ def get_field_value(cls, field_meta: "FieldMeta", field_build_parameters: Any |
field_meta.annotation = link_class
field_meta.annotation = link_class

return super().get_field_value(field_meta=field_meta, field_build_parameters=field_build_parameters)
return super().get_field_value(
field_meta=field_meta,
field_build_parameters=field_build_parameters,
build_context=build_context,
)
11 changes: 8 additions & 3 deletions polyfactory/factories/pydantic_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from polyfactory.collection_extender import CollectionExtender
from polyfactory.constants import DEFAULT_RANDOM
from polyfactory.exceptions import MissingDependencyException
from polyfactory.factories.base import BaseFactory
from polyfactory.factories.base import BaseFactory, BuildContext
from polyfactory.field_meta import Constraints, FieldMeta, Null
from polyfactory.utils.deprecation import check_for_deprecated_parameters
from polyfactory.utils.helpers import unwrap_new_type, unwrap_optional
Expand Down Expand Up @@ -368,7 +368,12 @@ def get_constrained_field_value(cls, annotation: Any, field_meta: FieldMeta) ->
return super().get_constrained_field_value(annotation, field_meta)

@classmethod
def build(cls, factory_use_construct: bool = False, **kwargs: Any) -> T:
def build(
cls,
build_context: BuildContext | None = None,
factory_use_construct: bool = False,
**kwargs: Any,
) -> T:
"""Build an instance of the factory's __model__
:param factory_use_construct: A boolean that determines whether validations will be made when instantiating the
Expand All @@ -378,7 +383,7 @@ def build(cls, factory_use_construct: bool = False, **kwargs: Any) -> T:
:returns: An instance of type T.
"""
processed_kwargs = cls.process_kwargs(**kwargs)
processed_kwargs = cls.process_kwargs(build_context, **kwargs)

if factory_use_construct:
return (
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ dependencies = [
]

[project.optional-dependencies]
sqlalchemy = ["sqlalchemy>=1.4.29",]
sqlalchemy = [
"sqlalchemy[asyncio]>=1.4.29",
]
pydantic = ["pydantic[email]",]
msgspec = ["msgspec",]
odmantic = ["odmantic<1.0.0", "pydantic[email]",]
Expand Down
Loading

0 comments on commit 5c0356d

Please sign in to comment.