From 5c0356d2801198061784811fe6eda1f63024b57b Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Mon, 1 Jan 2024 19:49:52 +0000 Subject: [PATCH 01/11] feat: add recursion guard --- .../sqlalchemy_factory/test_example_4.py | 60 +++++++ .../library_factories/sqlalchemy_factory.rst | 10 ++ pdm.lock | 147 ++++++++++++------ polyfactory/factories/attrs_factory.py | 2 +- polyfactory/factories/base.py | 85 ++++++---- polyfactory/factories/beanie_odm_factory.py | 14 +- polyfactory/factories/pydantic_factory.py | 11 +- pyproject.toml | 4 +- .../test_sqlalchemy_factory_common.py | 21 +++ tests/test_beanie_factory.py | 2 +- ...t_passing_build_args_to_child_factories.py | 4 +- tests/test_recursive_models.py | 39 +++++ 12 files changed, 315 insertions(+), 84 deletions(-) create mode 100644 docs/examples/library_factories/sqlalchemy_factory/test_example_4.py create mode 100644 tests/test_recursive_models.py diff --git a/docs/examples/library_factories/sqlalchemy_factory/test_example_4.py b/docs/examples/library_factories/sqlalchemy_factory/test_example_4.py new file mode 100644 index 00000000..b0af4135 --- /dev/null +++ b/docs/examples/library_factories/sqlalchemy_factory/test_example_4.py @@ -0,0 +1,60 @@ +from typing import List + +from sqlalchemy import ForeignKey, create_engine +from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship + +from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory, T + + +class Base(DeclarativeBase): + ... + + +class Author(Base): + __tablename__ = "authors" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + + books: Mapped[List["Book"]] = relationship( + "Book", + uselist=True, + back_populates="author", + ) + + +class Book(Base): + __tablename__ = "books" + + id: Mapped[int] = mapped_column(primary_key=True) + author_id: Mapped[int] = mapped_column(ForeignKey(Author.id), nullable=False) + author: Mapped[Author] = relationship( + "Author", + uselist=False, + back_populates="books", + ) + + +class BaseFactory(SQLAlchemyFactory[T]): + __is_base_factory__ = True + __set_relationships__ = True + __randomize_collection_length__ = True + __min_collection_length__ = 3 + + +def test_custom_sqla_factory() -> None: + engine = create_engine("sqlite:///:memory:") + Base.metadata.create_all(engine) + session = Session(engine) + + BaseFactory.__session__ = session # Or using a callable that returns a session + + author = BaseFactory.create_factory(Author).create_sync() + assert author.id is not None + assert author.id == author.books[0].author_id + + book = BaseFactory.create_factory(Book).create_sync() + assert book.id is not None + assert book.author.books == [book] + + BaseFactory._base_factories.remove(BaseFactory) # noqa: SLF001 diff --git a/docs/usage/library_factories/sqlalchemy_factory.rst b/docs/usage/library_factories/sqlalchemy_factory.rst index ef864dc7..ee72c9e9 100644 --- a/docs/usage/library_factories/sqlalchemy_factory.rst +++ b/docs/usage/library_factories/sqlalchemy_factory.rst @@ -37,6 +37,16 @@ By default, this will add generated models to the session and then commit. This Similarly for ``__async_session__`` and ``create_async``. +Adding global overrides +------------------------------ + +By combining the above and using other settings, a global base factory can be set up for other factories. + +.. literalinclude:: /examples/library_factories/sqlalchemy_factory/test_example_4.py + :caption: Using persistence + :language: python + + API reference ------------------------------ Full API docs are available :class:`here `. diff --git a/pdm.lock b/pdm.lock index 20e229a1..6d1d6728 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "msgspec", "sqlalchemy", "lint", "attrs", "full", "docs", "odmantic", "pydantic", "test", "dev", "beanie"] strategy = ["cross_platform"] lock_version = "4.4.1" -content_hash = "sha256:fdbe2e08c7fa2713913bf0f53e45ac79993264869568e92094eb4acebedf6274" +content_hash = "sha256:62840909635332a37b19c20794d822e617c142fa856f00d88098338d405ef77c" [[package]] name = "accessible-pygments" @@ -283,7 +283,7 @@ version = "23.12.0" requires_python = ">=3.8" summary = "The uncompromising code formatter." dependencies = [ - "aiohttp>=3.7.4; sys_platform != \"win32\" or implementation_name != \"pypy\" and extra == \"d\"", + "aiohttp>=3.7.4; sys_platform != \"win32\" or implementation_name != \"pypy\"", "click>=8.0.0", "mypy-extensions>=0.4.3", "packaging>=22.0", @@ -990,7 +990,7 @@ files = [ [[package]] name = "litestar-sphinx-theme" version = "0.2.0" -requires_python = ">=3.8,<4.0" +requires_python = "<4.0,>=3.8" git = "https://github.com/litestar-org/litestar-sphinx-theme.git" revision = "c5ce66aadc8f910c24f54bf0d172798c237a67eb" summary = "A Sphinx theme for the Litestar organization" @@ -2205,7 +2205,7 @@ files = [ [[package]] name = "sqlalchemy" -version = "2.0.23" +version = "2.0.24" requires_python = ">=3.7" summary = "Database Abstraction Library" dependencies = [ @@ -2213,48 +2213,103 @@ dependencies = [ "typing-extensions>=4.2.0", ] files = [ - {file = "SQLAlchemy-2.0.23-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:638c2c0b6b4661a4fd264f6fb804eccd392745c5887f9317feb64bb7cb03b3ea"}, - {file = "SQLAlchemy-2.0.23-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e3b5036aa326dc2df50cba3c958e29b291a80f604b1afa4c8ce73e78e1c9f01d"}, - {file = "SQLAlchemy-2.0.23-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:787af80107fb691934a01889ca8f82a44adedbf5ef3d6ad7d0f0b9ac557e0c34"}, - {file = "SQLAlchemy-2.0.23-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c14eba45983d2f48f7546bb32b47937ee2cafae353646295f0e99f35b14286ab"}, - {file = "SQLAlchemy-2.0.23-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0666031df46b9badba9bed00092a1ffa3aa063a5e68fa244acd9f08070e936d3"}, - {file = "SQLAlchemy-2.0.23-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:89a01238fcb9a8af118eaad3ffcc5dedaacbd429dc6fdc43fe430d3a941ff965"}, - {file = "SQLAlchemy-2.0.23-cp310-cp310-win32.whl", hash = "sha256:cabafc7837b6cec61c0e1e5c6d14ef250b675fa9c3060ed8a7e38653bd732ff8"}, - {file = "SQLAlchemy-2.0.23-cp310-cp310-win_amd64.whl", hash = "sha256:87a3d6b53c39cd173990de2f5f4b83431d534a74f0e2f88bd16eabb5667e65c6"}, - {file = "SQLAlchemy-2.0.23-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d5578e6863eeb998980c212a39106ea139bdc0b3f73291b96e27c929c90cd8e1"}, - {file = "SQLAlchemy-2.0.23-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:62d9e964870ea5ade4bc870ac4004c456efe75fb50404c03c5fd61f8bc669a72"}, - {file = "SQLAlchemy-2.0.23-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c80c38bd2ea35b97cbf7c21aeb129dcbebbf344ee01a7141016ab7b851464f8e"}, - {file = "SQLAlchemy-2.0.23-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75eefe09e98043cff2fb8af9796e20747ae870c903dc61d41b0c2e55128f958d"}, - {file = "SQLAlchemy-2.0.23-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:bd45a5b6c68357578263d74daab6ff9439517f87da63442d244f9f23df56138d"}, - {file = "SQLAlchemy-2.0.23-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a86cb7063e2c9fb8e774f77fbf8475516d270a3e989da55fa05d08089d77f8c4"}, - {file = "SQLAlchemy-2.0.23-cp311-cp311-win32.whl", hash = "sha256:b41f5d65b54cdf4934ecede2f41b9c60c9f785620416e8e6c48349ab18643855"}, - {file = "SQLAlchemy-2.0.23-cp311-cp311-win_amd64.whl", hash = "sha256:9ca922f305d67605668e93991aaf2c12239c78207bca3b891cd51a4515c72e22"}, - {file = "SQLAlchemy-2.0.23-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d0f7fb0c7527c41fa6fcae2be537ac137f636a41b4c5a4c58914541e2f436b45"}, - {file = "SQLAlchemy-2.0.23-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7c424983ab447dab126c39d3ce3be5bee95700783204a72549c3dceffe0fc8f4"}, - {file = "SQLAlchemy-2.0.23-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f508ba8f89e0a5ecdfd3761f82dda2a3d7b678a626967608f4273e0dba8f07ac"}, - {file = "SQLAlchemy-2.0.23-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6463aa765cf02b9247e38b35853923edbf2f6fd1963df88706bc1d02410a5577"}, - {file = "SQLAlchemy-2.0.23-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:e599a51acf3cc4d31d1a0cf248d8f8d863b6386d2b6782c5074427ebb7803bda"}, - {file = "SQLAlchemy-2.0.23-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:fd54601ef9cc455a0c61e5245f690c8a3ad67ddb03d3b91c361d076def0b4c60"}, - {file = "SQLAlchemy-2.0.23-cp312-cp312-win32.whl", hash = "sha256:42d0b0290a8fb0165ea2c2781ae66e95cca6e27a2fbe1016ff8db3112ac1e846"}, - {file = "SQLAlchemy-2.0.23-cp312-cp312-win_amd64.whl", hash = "sha256:227135ef1e48165f37590b8bfc44ed7ff4c074bf04dc8d6f8e7f1c14a94aa6ca"}, - {file = "SQLAlchemy-2.0.23-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:64ac935a90bc479fee77f9463f298943b0e60005fe5de2aa654d9cdef46c54df"}, - {file = "SQLAlchemy-2.0.23-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c4722f3bc3c1c2fcc3702dbe0016ba31148dd6efcd2a2fd33c1b4897c6a19693"}, - {file = "SQLAlchemy-2.0.23-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4af79c06825e2836de21439cb2a6ce22b2ca129bad74f359bddd173f39582bf5"}, - {file = "SQLAlchemy-2.0.23-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:683ef58ca8eea4747737a1c35c11372ffeb84578d3aab8f3e10b1d13d66f2bc4"}, - {file = "SQLAlchemy-2.0.23-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:d4041ad05b35f1f4da481f6b811b4af2f29e83af253bf37c3c4582b2c68934ab"}, - {file = "SQLAlchemy-2.0.23-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:aeb397de65a0a62f14c257f36a726945a7f7bb60253462e8602d9b97b5cbe204"}, - {file = "SQLAlchemy-2.0.23-cp38-cp38-win32.whl", hash = "sha256:42ede90148b73fe4ab4a089f3126b2cfae8cfefc955c8174d697bb46210c8306"}, - {file = "SQLAlchemy-2.0.23-cp38-cp38-win_amd64.whl", hash = "sha256:964971b52daab357d2c0875825e36584d58f536e920f2968df8d581054eada4b"}, - {file = "SQLAlchemy-2.0.23-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:616fe7bcff0a05098f64b4478b78ec2dfa03225c23734d83d6c169eb41a93e55"}, - {file = "SQLAlchemy-2.0.23-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0e680527245895aba86afbd5bef6c316831c02aa988d1aad83c47ffe92655e74"}, - {file = "SQLAlchemy-2.0.23-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9585b646ffb048c0250acc7dad92536591ffe35dba624bb8fd9b471e25212a35"}, - {file = "SQLAlchemy-2.0.23-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4895a63e2c271ffc7a81ea424b94060f7b3b03b4ea0cd58ab5bb676ed02f4221"}, - {file = "SQLAlchemy-2.0.23-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:cc1d21576f958c42d9aec68eba5c1a7d715e5fc07825a629015fe8e3b0657fb0"}, - {file = "SQLAlchemy-2.0.23-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:967c0b71156f793e6662dd839da54f884631755275ed71f1539c95bbada9aaab"}, - {file = "SQLAlchemy-2.0.23-cp39-cp39-win32.whl", hash = "sha256:0a8c6aa506893e25a04233bc721c6b6cf844bafd7250535abb56cb6cc1368884"}, - {file = "SQLAlchemy-2.0.23-cp39-cp39-win_amd64.whl", hash = "sha256:f3420d00d2cb42432c1d0e44540ae83185ccbbc67a6054dcc8ab5387add6620b"}, - {file = "SQLAlchemy-2.0.23-py3-none-any.whl", hash = "sha256:31952bbc527d633b9479f5f81e8b9dfada00b91d6baba021a869095f1a97006d"}, - {file = "SQLAlchemy-2.0.23.tar.gz", hash = "sha256:c1bda93cbbe4aa2aa0aa8655c5aeda505cd219ff3e8da91d1d329e143e4aff69"}, + {file = "SQLAlchemy-2.0.24-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5f801d85ba4753d4ed97181d003e5d3fa330ac7c4587d131f61d7f968f416862"}, + {file = "SQLAlchemy-2.0.24-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b35c35e3923ade1e7ac44e150dec29f5863513246c8bf85e2d7d313e3832bcfb"}, + {file = "SQLAlchemy-2.0.24-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d9b3fd5eca3c0b137a5e0e468e24ca544ed8ca4783e0e55341b7ed2807518ee"}, + {file = "SQLAlchemy-2.0.24-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a6209e689d0ff206c40032b6418e3cfcfc5af044b3f66e381d7f1ae301544b4"}, + {file = "SQLAlchemy-2.0.24-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:37e89d965b52e8b20571b5d44f26e2124b26ab63758bf1b7598a0e38fb2c4005"}, + {file = "SQLAlchemy-2.0.24-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c6910eb4ea90c0889f363965cd3c8c45a620ad27b526a7899f0054f6c1b9219e"}, + {file = "SQLAlchemy-2.0.24-cp310-cp310-win32.whl", hash = "sha256:d8e7e8a150e7b548e7ecd6ebb9211c37265991bf2504297d9454e01b58530fc6"}, + {file = "SQLAlchemy-2.0.24-cp310-cp310-win_amd64.whl", hash = "sha256:396f05c552f7fa30a129497c41bef5b4d1423f9af8fe4df0c3dcd38f3e3b9a14"}, + {file = "SQLAlchemy-2.0.24-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:adbd67dac4ebf54587198b63cd30c29fd7eafa8c0cab58893d9419414f8efe4b"}, + {file = "SQLAlchemy-2.0.24-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a0f611b431b84f55779cbb7157257d87b4a2876b067c77c4f36b15e44ced65e2"}, + {file = "SQLAlchemy-2.0.24-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:56a0e90a959e18ac5f18c80d0cad9e90cb09322764f536e8a637426afb1cae2f"}, + {file = "SQLAlchemy-2.0.24-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6db686a1d9f183c639f7e06a2656af25d4ed438eda581de135d15569f16ace33"}, + {file = "SQLAlchemy-2.0.24-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f0cc0b486a56dff72dddae6b6bfa7ff201b0eeac29d4bc6f0e9725dc3c360d71"}, + {file = "SQLAlchemy-2.0.24-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4a1d4856861ba9e73bac05030cec5852eabfa9ef4af8e56c19d92de80d46fc34"}, + {file = "SQLAlchemy-2.0.24-cp311-cp311-win32.whl", hash = "sha256:a3c2753bf4f48b7a6024e5e8a394af49b1b12c817d75d06942cae03d14ff87b3"}, + {file = "SQLAlchemy-2.0.24-cp311-cp311-win_amd64.whl", hash = "sha256:38732884eabc64982a09a846bacf085596ff2371e4e41d20c0734f7e50525d01"}, + {file = "SQLAlchemy-2.0.24-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:9f992e0f916201731993eab8502912878f02287d9f765ef843677ff118d0e0b1"}, + {file = "SQLAlchemy-2.0.24-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2587e108463cc2e5b45a896b2e7cc8659a517038026922a758bde009271aed11"}, + {file = "SQLAlchemy-2.0.24-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0bb7cedcddffca98c40bb0becd3423e293d1fef442b869da40843d751785beb3"}, + {file = "SQLAlchemy-2.0.24-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:83fa6df0e035689df89ff77a46bf8738696785d3156c2c61494acdcddc75c69d"}, + {file = "SQLAlchemy-2.0.24-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:cc889fda484d54d0b31feec409406267616536d048a450fc46943e152700bb79"}, + {file = "SQLAlchemy-2.0.24-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:57ef6f2cb8b09a042d0dbeaa46a30f2df5dd1e1eb889ba258b0d5d7d6011b81c"}, + {file = "SQLAlchemy-2.0.24-cp312-cp312-win32.whl", hash = "sha256:ea490564435b5b204d8154f0e18387b499ea3cedc1e6af3b3a2ab18291d85aa7"}, + {file = "SQLAlchemy-2.0.24-cp312-cp312-win_amd64.whl", hash = "sha256:ccfd336f96d4c9bbab0309f2a565bf15c468c2d8b2d277a32f89c5940f71fcf9"}, + {file = "SQLAlchemy-2.0.24-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9bafaa05b19dc07fa191c1966c5e852af516840b0d7b46b7c3303faf1a349bc9"}, + {file = "SQLAlchemy-2.0.24-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e69290b921b7833c04206f233d6814c60bee1d135b09f5ae5d39229de9b46cd4"}, + {file = "SQLAlchemy-2.0.24-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8398593ccc4440ce6dffcc4f47d9b2d72b9fe7112ac12ea4a44e7d4de364db1"}, + {file = "SQLAlchemy-2.0.24-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f073321a79c81e1a009218a21089f61d87ee5fa3c9563f6be94f8b41ff181812"}, + {file = "SQLAlchemy-2.0.24-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:9036ebfd934813990c5b9f71f297e77ed4963720db7d7ceec5a3fdb7cd2ef6ce"}, + {file = "SQLAlchemy-2.0.24-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fcf84fe93397a0f67733aa2a38ed4eab9fc6348189fc950e656e1ea198f45668"}, + {file = "SQLAlchemy-2.0.24-cp38-cp38-win32.whl", hash = "sha256:6f5e75de91c754365c098ac08c13fdb267577ce954fa239dd49228b573ca88d7"}, + {file = "SQLAlchemy-2.0.24-cp38-cp38-win_amd64.whl", hash = "sha256:9f29c7f0f4b42337ec5a779e166946a9f86d7d56d827e771b69ecbdf426124ac"}, + {file = "SQLAlchemy-2.0.24-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:07cc423892f2ceda9ae1daa28c0355757f362ecc7505b1ab1a3d5d8dc1c44ac6"}, + {file = "SQLAlchemy-2.0.24-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2a479aa1ab199178ff1956b09ca8a0693e70f9c762875d69292d37049ffd0d8f"}, + {file = "SQLAlchemy-2.0.24-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b8d0e8578e7f853f45f4512b5c920f6a546cd4bed44137460b2a56534644205"}, + {file = "SQLAlchemy-2.0.24-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e17e7e27af178d31b436dda6a596703b02a89ba74a15e2980c35ecd9909eea3a"}, + {file = "SQLAlchemy-2.0.24-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:1ca7903d5e7db791a355b579c690684fac6304478b68efdc7f2ebdcfe770d8d7"}, + {file = "SQLAlchemy-2.0.24-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:db09e424d7bb89b6215a184ca93b4f29d7f00ea261b787918a1af74143b98c06"}, + {file = "SQLAlchemy-2.0.24-cp39-cp39-win32.whl", hash = "sha256:a5cd7d30e47f87b21362beeb3e86f1b5886e7d9b0294b230dde3d3f4a1591375"}, + {file = "SQLAlchemy-2.0.24-cp39-cp39-win_amd64.whl", hash = "sha256:7ae5d44517fe81079ce75cf10f96978284a6db2642c5932a69c82dbae09f009a"}, + {file = "SQLAlchemy-2.0.24-py3-none-any.whl", hash = "sha256:8f358f5cfce04417b6ff738748ca4806fe3d3ae8040fb4e6a0c9a6973ccf9b6e"}, + {file = "SQLAlchemy-2.0.24.tar.gz", hash = "sha256:6db97656fd3fe3f7e5b077f12fa6adb5feb6e0b567a3e99f47ecf5f7ea0a09e3"}, +] + +[[package]] +name = "sqlalchemy" +version = "2.0.24" +extras = ["asyncio"] +requires_python = ">=3.7" +summary = "Database Abstraction Library" +dependencies = [ + "greenlet!=0.4.17", + "sqlalchemy==2.0.24", +] +files = [ + {file = "SQLAlchemy-2.0.24-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5f801d85ba4753d4ed97181d003e5d3fa330ac7c4587d131f61d7f968f416862"}, + {file = "SQLAlchemy-2.0.24-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b35c35e3923ade1e7ac44e150dec29f5863513246c8bf85e2d7d313e3832bcfb"}, + {file = "SQLAlchemy-2.0.24-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d9b3fd5eca3c0b137a5e0e468e24ca544ed8ca4783e0e55341b7ed2807518ee"}, + {file = "SQLAlchemy-2.0.24-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a6209e689d0ff206c40032b6418e3cfcfc5af044b3f66e381d7f1ae301544b4"}, + {file = "SQLAlchemy-2.0.24-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:37e89d965b52e8b20571b5d44f26e2124b26ab63758bf1b7598a0e38fb2c4005"}, + {file = "SQLAlchemy-2.0.24-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c6910eb4ea90c0889f363965cd3c8c45a620ad27b526a7899f0054f6c1b9219e"}, + {file = "SQLAlchemy-2.0.24-cp310-cp310-win32.whl", hash = "sha256:d8e7e8a150e7b548e7ecd6ebb9211c37265991bf2504297d9454e01b58530fc6"}, + {file = "SQLAlchemy-2.0.24-cp310-cp310-win_amd64.whl", hash = "sha256:396f05c552f7fa30a129497c41bef5b4d1423f9af8fe4df0c3dcd38f3e3b9a14"}, + {file = "SQLAlchemy-2.0.24-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:adbd67dac4ebf54587198b63cd30c29fd7eafa8c0cab58893d9419414f8efe4b"}, + {file = "SQLAlchemy-2.0.24-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a0f611b431b84f55779cbb7157257d87b4a2876b067c77c4f36b15e44ced65e2"}, + {file = "SQLAlchemy-2.0.24-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:56a0e90a959e18ac5f18c80d0cad9e90cb09322764f536e8a637426afb1cae2f"}, + {file = "SQLAlchemy-2.0.24-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6db686a1d9f183c639f7e06a2656af25d4ed438eda581de135d15569f16ace33"}, + {file = "SQLAlchemy-2.0.24-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f0cc0b486a56dff72dddae6b6bfa7ff201b0eeac29d4bc6f0e9725dc3c360d71"}, + {file = "SQLAlchemy-2.0.24-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4a1d4856861ba9e73bac05030cec5852eabfa9ef4af8e56c19d92de80d46fc34"}, + {file = "SQLAlchemy-2.0.24-cp311-cp311-win32.whl", hash = "sha256:a3c2753bf4f48b7a6024e5e8a394af49b1b12c817d75d06942cae03d14ff87b3"}, + {file = "SQLAlchemy-2.0.24-cp311-cp311-win_amd64.whl", hash = "sha256:38732884eabc64982a09a846bacf085596ff2371e4e41d20c0734f7e50525d01"}, + {file = "SQLAlchemy-2.0.24-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:9f992e0f916201731993eab8502912878f02287d9f765ef843677ff118d0e0b1"}, + {file = "SQLAlchemy-2.0.24-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2587e108463cc2e5b45a896b2e7cc8659a517038026922a758bde009271aed11"}, + {file = "SQLAlchemy-2.0.24-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0bb7cedcddffca98c40bb0becd3423e293d1fef442b869da40843d751785beb3"}, + {file = "SQLAlchemy-2.0.24-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:83fa6df0e035689df89ff77a46bf8738696785d3156c2c61494acdcddc75c69d"}, + {file = "SQLAlchemy-2.0.24-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:cc889fda484d54d0b31feec409406267616536d048a450fc46943e152700bb79"}, + {file = "SQLAlchemy-2.0.24-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:57ef6f2cb8b09a042d0dbeaa46a30f2df5dd1e1eb889ba258b0d5d7d6011b81c"}, + {file = "SQLAlchemy-2.0.24-cp312-cp312-win32.whl", hash = "sha256:ea490564435b5b204d8154f0e18387b499ea3cedc1e6af3b3a2ab18291d85aa7"}, + {file = "SQLAlchemy-2.0.24-cp312-cp312-win_amd64.whl", hash = "sha256:ccfd336f96d4c9bbab0309f2a565bf15c468c2d8b2d277a32f89c5940f71fcf9"}, + {file = "SQLAlchemy-2.0.24-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9bafaa05b19dc07fa191c1966c5e852af516840b0d7b46b7c3303faf1a349bc9"}, + {file = "SQLAlchemy-2.0.24-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e69290b921b7833c04206f233d6814c60bee1d135b09f5ae5d39229de9b46cd4"}, + {file = "SQLAlchemy-2.0.24-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8398593ccc4440ce6dffcc4f47d9b2d72b9fe7112ac12ea4a44e7d4de364db1"}, + {file = "SQLAlchemy-2.0.24-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f073321a79c81e1a009218a21089f61d87ee5fa3c9563f6be94f8b41ff181812"}, + {file = "SQLAlchemy-2.0.24-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:9036ebfd934813990c5b9f71f297e77ed4963720db7d7ceec5a3fdb7cd2ef6ce"}, + {file = "SQLAlchemy-2.0.24-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fcf84fe93397a0f67733aa2a38ed4eab9fc6348189fc950e656e1ea198f45668"}, + {file = "SQLAlchemy-2.0.24-cp38-cp38-win32.whl", hash = "sha256:6f5e75de91c754365c098ac08c13fdb267577ce954fa239dd49228b573ca88d7"}, + {file = "SQLAlchemy-2.0.24-cp38-cp38-win_amd64.whl", hash = "sha256:9f29c7f0f4b42337ec5a779e166946a9f86d7d56d827e771b69ecbdf426124ac"}, + {file = "SQLAlchemy-2.0.24-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:07cc423892f2ceda9ae1daa28c0355757f362ecc7505b1ab1a3d5d8dc1c44ac6"}, + {file = "SQLAlchemy-2.0.24-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2a479aa1ab199178ff1956b09ca8a0693e70f9c762875d69292d37049ffd0d8f"}, + {file = "SQLAlchemy-2.0.24-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b8d0e8578e7f853f45f4512b5c920f6a546cd4bed44137460b2a56534644205"}, + {file = "SQLAlchemy-2.0.24-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e17e7e27af178d31b436dda6a596703b02a89ba74a15e2980c35ecd9909eea3a"}, + {file = "SQLAlchemy-2.0.24-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:1ca7903d5e7db791a355b579c690684fac6304478b68efdc7f2ebdcfe770d8d7"}, + {file = "SQLAlchemy-2.0.24-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:db09e424d7bb89b6215a184ca93b4f29d7f00ea261b787918a1af74143b98c06"}, + {file = "SQLAlchemy-2.0.24-cp39-cp39-win32.whl", hash = "sha256:a5cd7d30e47f87b21362beeb3e86f1b5886e7d9b0294b230dde3d3f4a1591375"}, + {file = "SQLAlchemy-2.0.24-cp39-cp39-win_amd64.whl", hash = "sha256:7ae5d44517fe81079ce75cf10f96978284a6db2642c5932a69c82dbae09f009a"}, + {file = "SQLAlchemy-2.0.24-py3-none-any.whl", hash = "sha256:8f358f5cfce04417b6ff738748ca4806fe3d3ae8040fb4e6a0c9a6973ccf9b6e"}, + {file = "SQLAlchemy-2.0.24.tar.gz", hash = "sha256:6db97656fd3fe3f7e5b077f12fa6adb5feb6e0b567a3e99f47ecf5f7ea0a09e3"}, ] [[package]] diff --git a/polyfactory/factories/attrs_factory.py b/polyfactory/factories/attrs_factory.py index f9302845..00ffa033 100644 --- a/polyfactory/factories/attrs_factory.py +++ b/polyfactory/factories/attrs_factory.py @@ -79,4 +79,4 @@ def resolve_types(cls, model: type[T], **kwargs: Any) -> None: :param kwargs: Any parameters that need to be passed to `attrs.resolve_types`. """ - attrs.resolve_types(model, **kwargs) # type: ignore[type-var] + attrs.resolve_types(model, **kwargs) diff --git a/polyfactory/factories/base.py b/polyfactory/factories/base.py index 242b73cd..0d46be7a 100644 --- a/polyfactory/factories/base.py +++ b/polyfactory/factories/base.py @@ -39,6 +39,7 @@ Mapping, Sequence, Type, + TypedDict, TypeVar, cast, ) @@ -63,14 +64,7 @@ unwrap_optional, ) from polyfactory.utils.model_coverage import CoverageContainer, CoverageContainerCallable, resolve_kwargs_coverage -from polyfactory.utils.predicates import ( - get_type_origin, - is_any, - is_literal, - is_optional, - is_safe_subclass, - is_union, -) +from polyfactory.utils.predicates import get_type_origin, is_any, is_literal, is_optional, is_safe_subclass, is_union from polyfactory.value_generators.complex_types import handle_collection_type, handle_collection_type_coverage from polyfactory.value_generators.constrained_collections import ( handle_constrained_collection, @@ -86,11 +80,7 @@ from polyfactory.value_generators.constrained_strings import handle_constrained_string_or_bytes from polyfactory.value_generators.constrained_url import handle_constrained_url from polyfactory.value_generators.constrained_uuid import handle_constrained_uuid -from polyfactory.value_generators.primitives import ( - create_random_boolean, - create_random_bytes, - create_random_string, -) +from polyfactory.value_generators.primitives import create_random_boolean, create_random_bytes, create_random_string if TYPE_CHECKING: from typing_extensions import TypeGuard @@ -103,6 +93,15 @@ F = TypeVar("F", bound="BaseFactory[Any]") +class BuildContext(TypedDict): + seen_models: set[type] + + +def _get_build_context(build_context: BuildContext | None) -> BuildContext: + build_context = build_context or {"seen_models": set()} + return build_context.copy() + + class BaseFactory(ABC, Generic[T]): """Base Factory class - this class holds the main logic of the library""" @@ -271,7 +270,12 @@ def _get_async_persistence(cls) -> AsyncPersistenceProtocol[T]: ) @classmethod - def _handle_factory_field(cls, field_value: Any, field_build_parameters: Any | None = None) -> Any: + def _handle_factory_field( + cls, + field_value: Any, + build_context: BuildContext, + field_build_parameters: Any | None = None, + ) -> Any: """Handle a value defined on the factory class itself. :param field_value: A value defined as an attribute on the factory class. @@ -281,12 +285,12 @@ def _handle_factory_field(cls, field_value: Any, field_build_parameters: Any | N """ if is_safe_subclass(field_value, BaseFactory): if isinstance(field_build_parameters, Mapping): - return field_value.build(**field_build_parameters) + return field_value.build(build_context, **field_build_parameters) if isinstance(field_build_parameters, Sequence): - return [field_value.build(**parameter) for parameter in field_build_parameters] + return [field_value.build(build_context, **parameter) for parameter in field_build_parameters] - return field_value.build() + return field_value.build(build_context) if isinstance(field_value, Use): return field_value.to_value() @@ -615,6 +619,7 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912 cls, field_meta: FieldMeta, field_build_parameters: Any | None = None, + build_context: BuildContext | None = None, ) -> Any: """Return a field value on the subclass if existing, otherwise returns a mock value. @@ -624,6 +629,7 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912 :returns: An arbitrary value. """ + build_context = _get_build_context(build_context) if cls.is_ignored_type(field_meta.annotation): return None @@ -642,18 +648,27 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912 return cls.get_constrained_field_value(annotation=unwrapped_annotation, field_meta=field_meta) if BaseFactory.is_factory_type(annotation=unwrapped_annotation): + if not field_build_parameters and unwrapped_annotation in build_context["seen_models"]: + return None + return cls._get_or_create_factory(model=unwrapped_annotation).build( + build_context, **(field_build_parameters if isinstance(field_build_parameters, Mapping) else {}), ) if BaseFactory.is_batch_factory_type(annotation=unwrapped_annotation): factory = cls._get_or_create_factory(model=field_meta.type_args[0]) if isinstance(field_build_parameters, Sequence): - return [factory.build(**field_parameters) for field_parameters in field_build_parameters] + return [factory.build(build_context, **field_parameters) for field_parameters in field_build_parameters] + + if field_meta.type_args[0] in build_context["seen_models"]: + return [] + if not cls.__randomize_collection_length__: - return [factory.build()] + return [factory.build(build_context)] + batch_size = cls.__random__.randint(cls.__min_collection_length__, cls.__max_collection_length__) - return factory.batch(size=batch_size) + return factory.batch(size=batch_size, build_context=build_context) if (origin := get_type_origin(unwrapped_annotation)) and issubclass(origin, Collection): if cls.__randomize_collection_length__: @@ -833,7 +848,7 @@ def _check_declared_fields_exist_in_model(cls) -> None: raise ConfigurationException(error_message) @classmethod - def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]: + def process_kwargs(cls, build_context: BuildContext | None = None, **kwargs: Any) -> dict[str, Any]: """Process the given kwargs and generate values for the factory's model. :param kwargs: Any build kwargs. @@ -841,6 +856,9 @@ def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]: :returns: A dictionary of build results. """ + build_context = _get_build_context(build_context) + build_context["seen_models"] = build_context["seen_models"] | {cls.__model__} + result: dict[str, Any] = {**kwargs} generate_post: dict[str, PostGenerated] = {} @@ -863,10 +881,15 @@ def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]: result[field_meta.name] = cls._handle_factory_field( field_value=field_value, field_build_parameters=field_build_parameters, + build_context=build_context, ) continue - result[field_meta.name] = cls.get_field_value(field_meta, field_build_parameters=field_build_parameters) + result[field_meta.name] = cls.get_field_value( + field_meta, + field_build_parameters=field_build_parameters, + build_context=build_context, + ) for field_name, post_generator in generate_post.items(): result[field_name] = post_generator.to_value(field_name, result) @@ -874,7 +897,11 @@ def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]: return result @classmethod - def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]: + def process_kwargs_coverage( + cls, + build_context: BuildContext | None = None, + **kwargs: Any, + ) -> abc.Iterable[dict[str, Any]]: """Process the given kwargs and generate values for the factory's model. :param kwargs: Any build kwargs. @@ -882,6 +909,9 @@ def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]: :returns: A dictionary of build results. """ + build_context = _get_build_context(build_context) + build_context["seen_models"] = build_context["seen_models"] | {cls.__model__} + result: dict[str, Any] = {**kwargs} generate_post: dict[str, PostGenerated] = {} @@ -918,7 +948,7 @@ def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]: yield resolved @classmethod - def build(cls, **kwargs: Any) -> T: + def build(cls, build_context: BuildContext | None = None, **kwargs: Any) -> T: """Build an instance of the factory's __model__ :param kwargs: Any kwargs. If field names are set in kwargs, their values will be used. @@ -926,11 +956,10 @@ def build(cls, **kwargs: Any) -> T: :returns: An instance of type T. """ - - return cast("T", cls.__model__(**cls.process_kwargs(**kwargs))) + return cast("T", cls.__model__(**cls.process_kwargs(build_context, **kwargs))) @classmethod - def batch(cls, size: int, **kwargs: Any) -> list[T]: + def batch(cls, size: int, build_context: BuildContext | None = None, **kwargs: Any) -> list[T]: """Build a batch of size n of the factory's Meta.model. :param size: Size of the batch. @@ -939,7 +968,7 @@ def batch(cls, size: int, **kwargs: Any) -> list[T]: :returns: A list of instances of type T. """ - return [cls.build(**kwargs) for _ in range(size)] + return [cls.build(build_context, **kwargs) for _ in range(size)] @classmethod def coverage(cls, **kwargs: Any) -> abc.Iterator[T]: diff --git a/polyfactory/factories/beanie_odm_factory.py b/polyfactory/factories/beanie_odm_factory.py index 8f98a5cf..ddd31697 100644 --- a/polyfactory/factories/beanie_odm_factory.py +++ b/polyfactory/factories/beanie_odm_factory.py @@ -12,6 +12,7 @@ if TYPE_CHECKING: from typing_extensions import TypeGuard + from polyfactory.factories.base import BuildContext from polyfactory.field_meta import FieldMeta try: @@ -55,7 +56,12 @@ def is_supported_type(cls, value: Any) -> "TypeGuard[type[T]]": return is_safe_subclass(value, Document) @classmethod - def get_field_value(cls, field_meta: "FieldMeta", field_build_parameters: Any | None = None) -> Any: + def get_field_value( + cls, + field_meta: "FieldMeta", + field_build_parameters: Any | None = None, + build_context: BuildContext | None = None, + ) -> Any: """Return a field value on the subclass if existing, otherwise returns a mock value. :param field_meta: FieldMeta instance. @@ -74,4 +80,8 @@ def get_field_value(cls, field_meta: "FieldMeta", field_build_parameters: Any | field_meta.annotation = link_class field_meta.annotation = link_class - return super().get_field_value(field_meta=field_meta, field_build_parameters=field_build_parameters) + return super().get_field_value( + field_meta=field_meta, + field_build_parameters=field_build_parameters, + build_context=build_context, + ) diff --git a/polyfactory/factories/pydantic_factory.py b/polyfactory/factories/pydantic_factory.py index 64102e2c..8998bfa6 100644 --- a/polyfactory/factories/pydantic_factory.py +++ b/polyfactory/factories/pydantic_factory.py @@ -13,7 +13,7 @@ from polyfactory.collection_extender import CollectionExtender from polyfactory.constants import DEFAULT_RANDOM from polyfactory.exceptions import MissingDependencyException -from polyfactory.factories.base import BaseFactory +from polyfactory.factories.base import BaseFactory, BuildContext from polyfactory.field_meta import Constraints, FieldMeta, Null from polyfactory.utils.deprecation import check_for_deprecated_parameters from polyfactory.utils.helpers import unwrap_new_type, unwrap_optional @@ -368,7 +368,12 @@ def get_constrained_field_value(cls, annotation: Any, field_meta: FieldMeta) -> return super().get_constrained_field_value(annotation, field_meta) @classmethod - def build(cls, factory_use_construct: bool = False, **kwargs: Any) -> T: + def build( + cls, + build_context: BuildContext | None = None, + factory_use_construct: bool = False, + **kwargs: Any, + ) -> T: """Build an instance of the factory's __model__ :param factory_use_construct: A boolean that determines whether validations will be made when instantiating the @@ -378,7 +383,7 @@ def build(cls, factory_use_construct: bool = False, **kwargs: Any) -> T: :returns: An instance of type T. """ - processed_kwargs = cls.process_kwargs(**kwargs) + processed_kwargs = cls.process_kwargs(build_context, **kwargs) if factory_use_construct: return ( diff --git a/pyproject.toml b/pyproject.toml index 9644eb01..b144a0e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,9 @@ dependencies = [ ] [project.optional-dependencies] -sqlalchemy = ["sqlalchemy>=1.4.29",] +sqlalchemy = [ + "sqlalchemy[asyncio]>=1.4.29", +] pydantic = ["pydantic[email]",] msgspec = ["msgspec",] odmantic = ["odmantic<1.0.0", "pydantic[email]",] diff --git a/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py b/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py index 2bb7873e..8cdcf667 100644 --- a/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py +++ b/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py @@ -10,6 +10,7 @@ from sqlalchemy.orm.decl_api import DeclarativeMeta, registry from polyfactory.exceptions import ConfigurationException +from polyfactory.factories.base import BaseFactory from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory @@ -198,6 +199,26 @@ class AuthorFactory(SQLAlchemyFactory[Author]): assert isinstance(result.books[0], Book) +def test_sqla_factory_create() -> None: + engine = create_engine("sqlite:///:memory:") + Base.metadata.create_all(engine) + + class OverridenSQLAlchemyFactory(SQLAlchemyFactory): + __is_base_factory__ = True + __session__ = Session(engine) + __set_relationships__ = True + + author: Author = OverridenSQLAlchemyFactory.create_factory(Author).create_sync() + assert isinstance(author.books[0], Book) + assert author.books[0].author is author + + book = OverridenSQLAlchemyFactory.create_factory(Book).create_sync() + assert book.author is not None + assert book.author.books == [book] + + BaseFactory._base_factories.remove(OverridenSQLAlchemyFactory) + + async def test_invalid_peristence_config_raises() -> None: class AuthorFactory(SQLAlchemyFactory[Author]): __model__ = Author diff --git a/tests/test_beanie_factory.py b/tests/test_beanie_factory.py index 280ecb63..f48999a6 100644 --- a/tests/test_beanie_factory.py +++ b/tests/test_beanie_factory.py @@ -48,7 +48,7 @@ class MyOtherFactory(BeanieDocumentFactory): @pytest.fixture() async def beanie_init(mongo_connection: AsyncMongoMockClient) -> None: - await init_beanie(database=mongo_connection.db_name, document_models=[MyDocument, MyOtherDocument]) # type: ignore + await init_beanie(database=mongo_connection.db_name, document_models=[MyDocument, MyOtherDocument]) async def test_handling_of_beanie_types(beanie_init: Callable) -> None: diff --git a/tests/test_passing_build_args_to_child_factories.py b/tests/test_passing_build_args_to_child_factories.py index 1ba4ff2a..eab2c936 100644 --- a/tests/test_passing_build_args_to_child_factories.py +++ b/tests/test_passing_build_args_to_child_factories.py @@ -64,7 +64,7 @@ def test_factory_child_model_list() -> None: }, } - person = PersonFactory.build(factory_use_construct=False, **data) + person = PersonFactory.build(factory_use_construct=False, **data) # type: ignore[arg-type] assert person.name == "Jean" assert len(person.pets) == 2 @@ -174,7 +174,7 @@ class D(BaseModel): class DFactory(ModelFactory): __model__ = D - build_result = DFactory.build(factory_use_construct=False, **{"c": {"b": {"a": {"name": "test"}}}}) + build_result = DFactory.build(factory_use_construct=False, **{"c": {"b": {"a": {"name": "test"}}}}) # type: ignore[arg-type] assert build_result assert build_result.c.b.a.name == "test" diff --git a/tests/test_recursive_models.py b/tests/test_recursive_models.py new file mode 100644 index 00000000..1a2d1ba0 --- /dev/null +++ b/tests/test_recursive_models.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from polyfactory.factories.dataclass_factory import DataclassFactory + + +@dataclass +class Node: + a: int + child: Node | None + + +def test_recusive_model() -> None: + factory = DataclassFactory.create_factory(Node) + assert factory.build().child is None + assert factory.build(child={"child": None}).child.child is None # type: ignore[union-attr] + + +@dataclass +class Author: + name: str + books: list[Book] + + +@dataclass +class Book: + name: str + author: Author + + +def test_recusive_list_model() -> None: + factory = DataclassFactory.create_factory(Author) + assert factory.build().books[0].author is None + assert factory.build(books=[]).books == [] + + book_factory = DataclassFactory.create_factory(Book) + assert book_factory.build().author.books == [] + assert book_factory.build(author=None).author is None From 1ff9dc921c4200aea09b087acc73e7ba061f8b63 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Tue, 2 Jan 2024 11:08:46 +0000 Subject: [PATCH 02/11] fix: avoid 3.10+ set syntax --- polyfactory/factories/base.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/polyfactory/factories/base.py b/polyfactory/factories/base.py index 0d46be7a..e109d7fd 100644 --- a/polyfactory/factories/base.py +++ b/polyfactory/factories/base.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy from abc import ABC, abstractmethod from collections import Counter, abc, deque from contextlib import suppress @@ -98,8 +99,10 @@ class BuildContext(TypedDict): def _get_build_context(build_context: BuildContext | None) -> BuildContext: - build_context = build_context or {"seen_models": set()} - return build_context.copy() + if build_context is None: + return {"seen_models": set()} + + return copy.deepcopy(build_context) class BaseFactory(ABC, Generic[T]): @@ -857,7 +860,7 @@ def process_kwargs(cls, build_context: BuildContext | None = None, **kwargs: Any """ build_context = _get_build_context(build_context) - build_context["seen_models"] = build_context["seen_models"] | {cls.__model__} + build_context["seen_models"].add(cls.__model__) result: dict[str, Any] = {**kwargs} generate_post: dict[str, PostGenerated] = {} @@ -910,7 +913,7 @@ def process_kwargs_coverage( """ build_context = _get_build_context(build_context) - build_context["seen_models"] = build_context["seen_models"] | {cls.__model__} + build_context["seen_models"].add(cls.__model__) result: dict[str, Any] = {**kwargs} generate_post: dict[str, PostGenerated] = {} From d197838632966c5a4a8d36decfd0abd3328cecdb Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Tue, 2 Jan 2024 11:13:33 +0000 Subject: [PATCH 03/11] fix: avoid 3.10+ set syntax --- tests/test_recursive_models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_recursive_models.py b/tests/test_recursive_models.py index 1a2d1ba0..02a12c71 100644 --- a/tests/test_recursive_models.py +++ b/tests/test_recursive_models.py @@ -1,6 +1,7 @@ from __future__ import annotations from dataclasses import dataclass +from typing import Optional from polyfactory.factories.dataclass_factory import DataclassFactory @@ -8,7 +9,7 @@ @dataclass class Node: a: int - child: Node | None + child: Optional[Node] # noqa: UP007 def test_recusive_model() -> None: From 74a3ad855110363b563bb883dd3f1294f7203ca3 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Tue, 2 Jan 2024 11:16:21 +0000 Subject: [PATCH 04/11] fix: avoid 3.9+ typing syntax --- tests/test_recursive_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_recursive_models.py b/tests/test_recursive_models.py index 02a12c71..46edf0c9 100644 --- a/tests/test_recursive_models.py +++ b/tests/test_recursive_models.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Optional +from typing import List, Optional from polyfactory.factories.dataclass_factory import DataclassFactory @@ -21,7 +21,7 @@ def test_recusive_model() -> None: @dataclass class Author: name: str - books: list[Book] + books: List[Book] # noqa: UP006 @dataclass From 82ae1ffc8df9ad6cdfec6aa6e37c0cf47d906c55 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Tue, 2 Jan 2024 11:23:46 +0000 Subject: [PATCH 05/11] fix: revert lock file changes --- pdm.lock | 147 ++++++++++++++++--------------------------------- pyproject.toml | 4 +- 2 files changed, 47 insertions(+), 104 deletions(-) diff --git a/pdm.lock b/pdm.lock index 6d1d6728..20e229a1 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "msgspec", "sqlalchemy", "lint", "attrs", "full", "docs", "odmantic", "pydantic", "test", "dev", "beanie"] strategy = ["cross_platform"] lock_version = "4.4.1" -content_hash = "sha256:62840909635332a37b19c20794d822e617c142fa856f00d88098338d405ef77c" +content_hash = "sha256:fdbe2e08c7fa2713913bf0f53e45ac79993264869568e92094eb4acebedf6274" [[package]] name = "accessible-pygments" @@ -283,7 +283,7 @@ version = "23.12.0" requires_python = ">=3.8" summary = "The uncompromising code formatter." dependencies = [ - "aiohttp>=3.7.4; sys_platform != \"win32\" or implementation_name != \"pypy\"", + "aiohttp>=3.7.4; sys_platform != \"win32\" or implementation_name != \"pypy\" and extra == \"d\"", "click>=8.0.0", "mypy-extensions>=0.4.3", "packaging>=22.0", @@ -990,7 +990,7 @@ files = [ [[package]] name = "litestar-sphinx-theme" version = "0.2.0" -requires_python = "<4.0,>=3.8" +requires_python = ">=3.8,<4.0" git = "https://github.com/litestar-org/litestar-sphinx-theme.git" revision = "c5ce66aadc8f910c24f54bf0d172798c237a67eb" summary = "A Sphinx theme for the Litestar organization" @@ -2205,7 +2205,7 @@ files = [ [[package]] name = "sqlalchemy" -version = "2.0.24" +version = "2.0.23" requires_python = ">=3.7" summary = "Database Abstraction Library" dependencies = [ @@ -2213,103 +2213,48 @@ dependencies = [ "typing-extensions>=4.2.0", ] files = [ - {file = "SQLAlchemy-2.0.24-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5f801d85ba4753d4ed97181d003e5d3fa330ac7c4587d131f61d7f968f416862"}, - {file = "SQLAlchemy-2.0.24-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b35c35e3923ade1e7ac44e150dec29f5863513246c8bf85e2d7d313e3832bcfb"}, - {file = "SQLAlchemy-2.0.24-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d9b3fd5eca3c0b137a5e0e468e24ca544ed8ca4783e0e55341b7ed2807518ee"}, - {file = "SQLAlchemy-2.0.24-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a6209e689d0ff206c40032b6418e3cfcfc5af044b3f66e381d7f1ae301544b4"}, - {file = "SQLAlchemy-2.0.24-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:37e89d965b52e8b20571b5d44f26e2124b26ab63758bf1b7598a0e38fb2c4005"}, - {file = "SQLAlchemy-2.0.24-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c6910eb4ea90c0889f363965cd3c8c45a620ad27b526a7899f0054f6c1b9219e"}, - {file = "SQLAlchemy-2.0.24-cp310-cp310-win32.whl", hash = "sha256:d8e7e8a150e7b548e7ecd6ebb9211c37265991bf2504297d9454e01b58530fc6"}, - {file = "SQLAlchemy-2.0.24-cp310-cp310-win_amd64.whl", hash = "sha256:396f05c552f7fa30a129497c41bef5b4d1423f9af8fe4df0c3dcd38f3e3b9a14"}, - {file = "SQLAlchemy-2.0.24-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:adbd67dac4ebf54587198b63cd30c29fd7eafa8c0cab58893d9419414f8efe4b"}, - {file = "SQLAlchemy-2.0.24-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a0f611b431b84f55779cbb7157257d87b4a2876b067c77c4f36b15e44ced65e2"}, - {file = "SQLAlchemy-2.0.24-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:56a0e90a959e18ac5f18c80d0cad9e90cb09322764f536e8a637426afb1cae2f"}, - {file = "SQLAlchemy-2.0.24-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6db686a1d9f183c639f7e06a2656af25d4ed438eda581de135d15569f16ace33"}, - {file = "SQLAlchemy-2.0.24-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f0cc0b486a56dff72dddae6b6bfa7ff201b0eeac29d4bc6f0e9725dc3c360d71"}, - {file = "SQLAlchemy-2.0.24-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4a1d4856861ba9e73bac05030cec5852eabfa9ef4af8e56c19d92de80d46fc34"}, - {file = "SQLAlchemy-2.0.24-cp311-cp311-win32.whl", hash = "sha256:a3c2753bf4f48b7a6024e5e8a394af49b1b12c817d75d06942cae03d14ff87b3"}, - {file = "SQLAlchemy-2.0.24-cp311-cp311-win_amd64.whl", hash = "sha256:38732884eabc64982a09a846bacf085596ff2371e4e41d20c0734f7e50525d01"}, - {file = "SQLAlchemy-2.0.24-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:9f992e0f916201731993eab8502912878f02287d9f765ef843677ff118d0e0b1"}, - {file = "SQLAlchemy-2.0.24-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2587e108463cc2e5b45a896b2e7cc8659a517038026922a758bde009271aed11"}, - {file = "SQLAlchemy-2.0.24-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0bb7cedcddffca98c40bb0becd3423e293d1fef442b869da40843d751785beb3"}, - {file = "SQLAlchemy-2.0.24-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:83fa6df0e035689df89ff77a46bf8738696785d3156c2c61494acdcddc75c69d"}, - {file = "SQLAlchemy-2.0.24-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:cc889fda484d54d0b31feec409406267616536d048a450fc46943e152700bb79"}, - {file = "SQLAlchemy-2.0.24-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:57ef6f2cb8b09a042d0dbeaa46a30f2df5dd1e1eb889ba258b0d5d7d6011b81c"}, - {file = "SQLAlchemy-2.0.24-cp312-cp312-win32.whl", hash = "sha256:ea490564435b5b204d8154f0e18387b499ea3cedc1e6af3b3a2ab18291d85aa7"}, - {file = "SQLAlchemy-2.0.24-cp312-cp312-win_amd64.whl", hash = "sha256:ccfd336f96d4c9bbab0309f2a565bf15c468c2d8b2d277a32f89c5940f71fcf9"}, - {file = "SQLAlchemy-2.0.24-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9bafaa05b19dc07fa191c1966c5e852af516840b0d7b46b7c3303faf1a349bc9"}, - {file = "SQLAlchemy-2.0.24-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e69290b921b7833c04206f233d6814c60bee1d135b09f5ae5d39229de9b46cd4"}, - {file = "SQLAlchemy-2.0.24-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8398593ccc4440ce6dffcc4f47d9b2d72b9fe7112ac12ea4a44e7d4de364db1"}, - {file = "SQLAlchemy-2.0.24-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f073321a79c81e1a009218a21089f61d87ee5fa3c9563f6be94f8b41ff181812"}, - {file = "SQLAlchemy-2.0.24-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:9036ebfd934813990c5b9f71f297e77ed4963720db7d7ceec5a3fdb7cd2ef6ce"}, - {file = "SQLAlchemy-2.0.24-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fcf84fe93397a0f67733aa2a38ed4eab9fc6348189fc950e656e1ea198f45668"}, - {file = "SQLAlchemy-2.0.24-cp38-cp38-win32.whl", hash = "sha256:6f5e75de91c754365c098ac08c13fdb267577ce954fa239dd49228b573ca88d7"}, - {file = "SQLAlchemy-2.0.24-cp38-cp38-win_amd64.whl", hash = "sha256:9f29c7f0f4b42337ec5a779e166946a9f86d7d56d827e771b69ecbdf426124ac"}, - {file = "SQLAlchemy-2.0.24-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:07cc423892f2ceda9ae1daa28c0355757f362ecc7505b1ab1a3d5d8dc1c44ac6"}, - {file = "SQLAlchemy-2.0.24-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2a479aa1ab199178ff1956b09ca8a0693e70f9c762875d69292d37049ffd0d8f"}, - {file = "SQLAlchemy-2.0.24-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b8d0e8578e7f853f45f4512b5c920f6a546cd4bed44137460b2a56534644205"}, - {file = "SQLAlchemy-2.0.24-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e17e7e27af178d31b436dda6a596703b02a89ba74a15e2980c35ecd9909eea3a"}, - {file = "SQLAlchemy-2.0.24-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:1ca7903d5e7db791a355b579c690684fac6304478b68efdc7f2ebdcfe770d8d7"}, - {file = "SQLAlchemy-2.0.24-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:db09e424d7bb89b6215a184ca93b4f29d7f00ea261b787918a1af74143b98c06"}, - {file = "SQLAlchemy-2.0.24-cp39-cp39-win32.whl", hash = "sha256:a5cd7d30e47f87b21362beeb3e86f1b5886e7d9b0294b230dde3d3f4a1591375"}, - {file = "SQLAlchemy-2.0.24-cp39-cp39-win_amd64.whl", hash = "sha256:7ae5d44517fe81079ce75cf10f96978284a6db2642c5932a69c82dbae09f009a"}, - {file = "SQLAlchemy-2.0.24-py3-none-any.whl", hash = "sha256:8f358f5cfce04417b6ff738748ca4806fe3d3ae8040fb4e6a0c9a6973ccf9b6e"}, - {file = "SQLAlchemy-2.0.24.tar.gz", hash = "sha256:6db97656fd3fe3f7e5b077f12fa6adb5feb6e0b567a3e99f47ecf5f7ea0a09e3"}, -] - -[[package]] -name = "sqlalchemy" -version = "2.0.24" -extras = ["asyncio"] -requires_python = ">=3.7" -summary = "Database Abstraction Library" -dependencies = [ - "greenlet!=0.4.17", - "sqlalchemy==2.0.24", -] -files = [ - {file = "SQLAlchemy-2.0.24-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5f801d85ba4753d4ed97181d003e5d3fa330ac7c4587d131f61d7f968f416862"}, - {file = "SQLAlchemy-2.0.24-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b35c35e3923ade1e7ac44e150dec29f5863513246c8bf85e2d7d313e3832bcfb"}, - {file = "SQLAlchemy-2.0.24-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d9b3fd5eca3c0b137a5e0e468e24ca544ed8ca4783e0e55341b7ed2807518ee"}, - {file = "SQLAlchemy-2.0.24-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a6209e689d0ff206c40032b6418e3cfcfc5af044b3f66e381d7f1ae301544b4"}, - {file = "SQLAlchemy-2.0.24-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:37e89d965b52e8b20571b5d44f26e2124b26ab63758bf1b7598a0e38fb2c4005"}, - {file = "SQLAlchemy-2.0.24-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c6910eb4ea90c0889f363965cd3c8c45a620ad27b526a7899f0054f6c1b9219e"}, - {file = "SQLAlchemy-2.0.24-cp310-cp310-win32.whl", hash = "sha256:d8e7e8a150e7b548e7ecd6ebb9211c37265991bf2504297d9454e01b58530fc6"}, - {file = "SQLAlchemy-2.0.24-cp310-cp310-win_amd64.whl", hash = "sha256:396f05c552f7fa30a129497c41bef5b4d1423f9af8fe4df0c3dcd38f3e3b9a14"}, - {file = "SQLAlchemy-2.0.24-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:adbd67dac4ebf54587198b63cd30c29fd7eafa8c0cab58893d9419414f8efe4b"}, - {file = "SQLAlchemy-2.0.24-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a0f611b431b84f55779cbb7157257d87b4a2876b067c77c4f36b15e44ced65e2"}, - {file = "SQLAlchemy-2.0.24-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:56a0e90a959e18ac5f18c80d0cad9e90cb09322764f536e8a637426afb1cae2f"}, - {file = "SQLAlchemy-2.0.24-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6db686a1d9f183c639f7e06a2656af25d4ed438eda581de135d15569f16ace33"}, - {file = "SQLAlchemy-2.0.24-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f0cc0b486a56dff72dddae6b6bfa7ff201b0eeac29d4bc6f0e9725dc3c360d71"}, - {file = "SQLAlchemy-2.0.24-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4a1d4856861ba9e73bac05030cec5852eabfa9ef4af8e56c19d92de80d46fc34"}, - {file = "SQLAlchemy-2.0.24-cp311-cp311-win32.whl", hash = "sha256:a3c2753bf4f48b7a6024e5e8a394af49b1b12c817d75d06942cae03d14ff87b3"}, - {file = "SQLAlchemy-2.0.24-cp311-cp311-win_amd64.whl", hash = "sha256:38732884eabc64982a09a846bacf085596ff2371e4e41d20c0734f7e50525d01"}, - {file = "SQLAlchemy-2.0.24-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:9f992e0f916201731993eab8502912878f02287d9f765ef843677ff118d0e0b1"}, - {file = "SQLAlchemy-2.0.24-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2587e108463cc2e5b45a896b2e7cc8659a517038026922a758bde009271aed11"}, - {file = "SQLAlchemy-2.0.24-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0bb7cedcddffca98c40bb0becd3423e293d1fef442b869da40843d751785beb3"}, - {file = "SQLAlchemy-2.0.24-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:83fa6df0e035689df89ff77a46bf8738696785d3156c2c61494acdcddc75c69d"}, - {file = "SQLAlchemy-2.0.24-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:cc889fda484d54d0b31feec409406267616536d048a450fc46943e152700bb79"}, - {file = "SQLAlchemy-2.0.24-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:57ef6f2cb8b09a042d0dbeaa46a30f2df5dd1e1eb889ba258b0d5d7d6011b81c"}, - {file = "SQLAlchemy-2.0.24-cp312-cp312-win32.whl", hash = "sha256:ea490564435b5b204d8154f0e18387b499ea3cedc1e6af3b3a2ab18291d85aa7"}, - {file = "SQLAlchemy-2.0.24-cp312-cp312-win_amd64.whl", hash = "sha256:ccfd336f96d4c9bbab0309f2a565bf15c468c2d8b2d277a32f89c5940f71fcf9"}, - {file = "SQLAlchemy-2.0.24-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9bafaa05b19dc07fa191c1966c5e852af516840b0d7b46b7c3303faf1a349bc9"}, - {file = "SQLAlchemy-2.0.24-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e69290b921b7833c04206f233d6814c60bee1d135b09f5ae5d39229de9b46cd4"}, - {file = "SQLAlchemy-2.0.24-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8398593ccc4440ce6dffcc4f47d9b2d72b9fe7112ac12ea4a44e7d4de364db1"}, - {file = "SQLAlchemy-2.0.24-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f073321a79c81e1a009218a21089f61d87ee5fa3c9563f6be94f8b41ff181812"}, - {file = "SQLAlchemy-2.0.24-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:9036ebfd934813990c5b9f71f297e77ed4963720db7d7ceec5a3fdb7cd2ef6ce"}, - {file = "SQLAlchemy-2.0.24-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fcf84fe93397a0f67733aa2a38ed4eab9fc6348189fc950e656e1ea198f45668"}, - {file = "SQLAlchemy-2.0.24-cp38-cp38-win32.whl", hash = "sha256:6f5e75de91c754365c098ac08c13fdb267577ce954fa239dd49228b573ca88d7"}, - {file = "SQLAlchemy-2.0.24-cp38-cp38-win_amd64.whl", hash = "sha256:9f29c7f0f4b42337ec5a779e166946a9f86d7d56d827e771b69ecbdf426124ac"}, - {file = "SQLAlchemy-2.0.24-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:07cc423892f2ceda9ae1daa28c0355757f362ecc7505b1ab1a3d5d8dc1c44ac6"}, - {file = "SQLAlchemy-2.0.24-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2a479aa1ab199178ff1956b09ca8a0693e70f9c762875d69292d37049ffd0d8f"}, - {file = "SQLAlchemy-2.0.24-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b8d0e8578e7f853f45f4512b5c920f6a546cd4bed44137460b2a56534644205"}, - {file = "SQLAlchemy-2.0.24-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e17e7e27af178d31b436dda6a596703b02a89ba74a15e2980c35ecd9909eea3a"}, - {file = "SQLAlchemy-2.0.24-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:1ca7903d5e7db791a355b579c690684fac6304478b68efdc7f2ebdcfe770d8d7"}, - {file = "SQLAlchemy-2.0.24-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:db09e424d7bb89b6215a184ca93b4f29d7f00ea261b787918a1af74143b98c06"}, - {file = "SQLAlchemy-2.0.24-cp39-cp39-win32.whl", hash = "sha256:a5cd7d30e47f87b21362beeb3e86f1b5886e7d9b0294b230dde3d3f4a1591375"}, - {file = "SQLAlchemy-2.0.24-cp39-cp39-win_amd64.whl", hash = "sha256:7ae5d44517fe81079ce75cf10f96978284a6db2642c5932a69c82dbae09f009a"}, - {file = "SQLAlchemy-2.0.24-py3-none-any.whl", hash = "sha256:8f358f5cfce04417b6ff738748ca4806fe3d3ae8040fb4e6a0c9a6973ccf9b6e"}, - {file = "SQLAlchemy-2.0.24.tar.gz", hash = "sha256:6db97656fd3fe3f7e5b077f12fa6adb5feb6e0b567a3e99f47ecf5f7ea0a09e3"}, + {file = "SQLAlchemy-2.0.23-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:638c2c0b6b4661a4fd264f6fb804eccd392745c5887f9317feb64bb7cb03b3ea"}, + {file = "SQLAlchemy-2.0.23-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e3b5036aa326dc2df50cba3c958e29b291a80f604b1afa4c8ce73e78e1c9f01d"}, + {file = "SQLAlchemy-2.0.23-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:787af80107fb691934a01889ca8f82a44adedbf5ef3d6ad7d0f0b9ac557e0c34"}, + {file = "SQLAlchemy-2.0.23-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c14eba45983d2f48f7546bb32b47937ee2cafae353646295f0e99f35b14286ab"}, + {file = "SQLAlchemy-2.0.23-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0666031df46b9badba9bed00092a1ffa3aa063a5e68fa244acd9f08070e936d3"}, + {file = "SQLAlchemy-2.0.23-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:89a01238fcb9a8af118eaad3ffcc5dedaacbd429dc6fdc43fe430d3a941ff965"}, + {file = "SQLAlchemy-2.0.23-cp310-cp310-win32.whl", hash = "sha256:cabafc7837b6cec61c0e1e5c6d14ef250b675fa9c3060ed8a7e38653bd732ff8"}, + {file = "SQLAlchemy-2.0.23-cp310-cp310-win_amd64.whl", hash = "sha256:87a3d6b53c39cd173990de2f5f4b83431d534a74f0e2f88bd16eabb5667e65c6"}, + {file = "SQLAlchemy-2.0.23-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d5578e6863eeb998980c212a39106ea139bdc0b3f73291b96e27c929c90cd8e1"}, + {file = "SQLAlchemy-2.0.23-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:62d9e964870ea5ade4bc870ac4004c456efe75fb50404c03c5fd61f8bc669a72"}, + {file = "SQLAlchemy-2.0.23-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c80c38bd2ea35b97cbf7c21aeb129dcbebbf344ee01a7141016ab7b851464f8e"}, + {file = "SQLAlchemy-2.0.23-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75eefe09e98043cff2fb8af9796e20747ae870c903dc61d41b0c2e55128f958d"}, + {file = "SQLAlchemy-2.0.23-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:bd45a5b6c68357578263d74daab6ff9439517f87da63442d244f9f23df56138d"}, + {file = "SQLAlchemy-2.0.23-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a86cb7063e2c9fb8e774f77fbf8475516d270a3e989da55fa05d08089d77f8c4"}, + {file = "SQLAlchemy-2.0.23-cp311-cp311-win32.whl", hash = "sha256:b41f5d65b54cdf4934ecede2f41b9c60c9f785620416e8e6c48349ab18643855"}, + {file = "SQLAlchemy-2.0.23-cp311-cp311-win_amd64.whl", hash = "sha256:9ca922f305d67605668e93991aaf2c12239c78207bca3b891cd51a4515c72e22"}, + {file = "SQLAlchemy-2.0.23-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d0f7fb0c7527c41fa6fcae2be537ac137f636a41b4c5a4c58914541e2f436b45"}, + {file = "SQLAlchemy-2.0.23-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7c424983ab447dab126c39d3ce3be5bee95700783204a72549c3dceffe0fc8f4"}, + {file = "SQLAlchemy-2.0.23-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f508ba8f89e0a5ecdfd3761f82dda2a3d7b678a626967608f4273e0dba8f07ac"}, + {file = "SQLAlchemy-2.0.23-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6463aa765cf02b9247e38b35853923edbf2f6fd1963df88706bc1d02410a5577"}, + {file = "SQLAlchemy-2.0.23-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:e599a51acf3cc4d31d1a0cf248d8f8d863b6386d2b6782c5074427ebb7803bda"}, + {file = "SQLAlchemy-2.0.23-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:fd54601ef9cc455a0c61e5245f690c8a3ad67ddb03d3b91c361d076def0b4c60"}, + {file = "SQLAlchemy-2.0.23-cp312-cp312-win32.whl", hash = "sha256:42d0b0290a8fb0165ea2c2781ae66e95cca6e27a2fbe1016ff8db3112ac1e846"}, + {file = "SQLAlchemy-2.0.23-cp312-cp312-win_amd64.whl", hash = "sha256:227135ef1e48165f37590b8bfc44ed7ff4c074bf04dc8d6f8e7f1c14a94aa6ca"}, + {file = "SQLAlchemy-2.0.23-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:64ac935a90bc479fee77f9463f298943b0e60005fe5de2aa654d9cdef46c54df"}, + {file = "SQLAlchemy-2.0.23-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c4722f3bc3c1c2fcc3702dbe0016ba31148dd6efcd2a2fd33c1b4897c6a19693"}, + {file = "SQLAlchemy-2.0.23-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4af79c06825e2836de21439cb2a6ce22b2ca129bad74f359bddd173f39582bf5"}, + {file = "SQLAlchemy-2.0.23-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:683ef58ca8eea4747737a1c35c11372ffeb84578d3aab8f3e10b1d13d66f2bc4"}, + {file = "SQLAlchemy-2.0.23-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:d4041ad05b35f1f4da481f6b811b4af2f29e83af253bf37c3c4582b2c68934ab"}, + {file = "SQLAlchemy-2.0.23-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:aeb397de65a0a62f14c257f36a726945a7f7bb60253462e8602d9b97b5cbe204"}, + {file = "SQLAlchemy-2.0.23-cp38-cp38-win32.whl", hash = "sha256:42ede90148b73fe4ab4a089f3126b2cfae8cfefc955c8174d697bb46210c8306"}, + {file = "SQLAlchemy-2.0.23-cp38-cp38-win_amd64.whl", hash = "sha256:964971b52daab357d2c0875825e36584d58f536e920f2968df8d581054eada4b"}, + {file = "SQLAlchemy-2.0.23-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:616fe7bcff0a05098f64b4478b78ec2dfa03225c23734d83d6c169eb41a93e55"}, + {file = "SQLAlchemy-2.0.23-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0e680527245895aba86afbd5bef6c316831c02aa988d1aad83c47ffe92655e74"}, + {file = "SQLAlchemy-2.0.23-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9585b646ffb048c0250acc7dad92536591ffe35dba624bb8fd9b471e25212a35"}, + {file = "SQLAlchemy-2.0.23-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4895a63e2c271ffc7a81ea424b94060f7b3b03b4ea0cd58ab5bb676ed02f4221"}, + {file = "SQLAlchemy-2.0.23-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:cc1d21576f958c42d9aec68eba5c1a7d715e5fc07825a629015fe8e3b0657fb0"}, + {file = "SQLAlchemy-2.0.23-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:967c0b71156f793e6662dd839da54f884631755275ed71f1539c95bbada9aaab"}, + {file = "SQLAlchemy-2.0.23-cp39-cp39-win32.whl", hash = "sha256:0a8c6aa506893e25a04233bc721c6b6cf844bafd7250535abb56cb6cc1368884"}, + {file = "SQLAlchemy-2.0.23-cp39-cp39-win_amd64.whl", hash = "sha256:f3420d00d2cb42432c1d0e44540ae83185ccbbc67a6054dcc8ab5387add6620b"}, + {file = "SQLAlchemy-2.0.23-py3-none-any.whl", hash = "sha256:31952bbc527d633b9479f5f81e8b9dfada00b91d6baba021a869095f1a97006d"}, + {file = "SQLAlchemy-2.0.23.tar.gz", hash = "sha256:c1bda93cbbe4aa2aa0aa8655c5aeda505cd219ff3e8da91d1d329e143e4aff69"}, ] [[package]] diff --git a/pyproject.toml b/pyproject.toml index b144a0e8..9644eb01 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,9 +42,7 @@ dependencies = [ ] [project.optional-dependencies] -sqlalchemy = [ - "sqlalchemy[asyncio]>=1.4.29", -] +sqlalchemy = ["sqlalchemy>=1.4.29",] pydantic = ["pydantic[email]",] msgspec = ["msgspec",] odmantic = ["odmantic<1.0.0", "pydantic[email]",] From 66475dba47c19883c2710af8d91fa4ea0f92be38 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Wed, 3 Jan 2024 14:01:24 +0000 Subject: [PATCH 06/11] fix: address PR comments --- polyfactory/factories/base.py | 31 +++++++---- polyfactory/factories/pydantic_factory.py | 4 +- polyfactory/utils/helpers.py | 13 +---- tests/test_recursive_models.py | 67 ++++++++++++++++++++--- 4 files changed, 85 insertions(+), 30 deletions(-) diff --git a/polyfactory/factories/base.py b/polyfactory/factories/base.py index e109d7fd..4ea192dd 100644 --- a/polyfactory/factories/base.py +++ b/polyfactory/factories/base.py @@ -288,12 +288,14 @@ def _handle_factory_field( """ if is_safe_subclass(field_value, BaseFactory): if isinstance(field_build_parameters, Mapping): - return field_value.build(build_context, **field_build_parameters) + return field_value.build(build_context=build_context, **field_build_parameters) if isinstance(field_build_parameters, Sequence): - return [field_value.build(build_context, **parameter) for parameter in field_build_parameters] + return [ + field_value.build(build_context=build_context, **parameter) for parameter in field_build_parameters + ] - return field_value.build(build_context) + return field_value.build(build_context=build_context) if isinstance(field_value, Use): return field_value.to_value() @@ -628,6 +630,7 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912 :param field_meta: FieldMeta instance. :param field_build_parameters: Any build parameters passed to the factory as kwarg values. + :param build_context: BuildContext data for current build. :returns: An arbitrary value. @@ -652,7 +655,7 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912 if BaseFactory.is_factory_type(annotation=unwrapped_annotation): if not field_build_parameters and unwrapped_annotation in build_context["seen_models"]: - return None + return None if is_optional(field_meta.annotation) else Ignore return cls._get_or_create_factory(model=unwrapped_annotation).build( build_context, @@ -662,18 +665,21 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912 if BaseFactory.is_batch_factory_type(annotation=unwrapped_annotation): factory = cls._get_or_create_factory(model=field_meta.type_args[0]) if isinstance(field_build_parameters, Sequence): - return [factory.build(build_context, **field_parameters) for field_parameters in field_build_parameters] + return [ + factory.build(build_context=build_context, **field_parameters) + for field_parameters in field_build_parameters + ] if field_meta.type_args[0] in build_context["seen_models"]: return [] if not cls.__randomize_collection_length__: - return [factory.build(build_context)] + return [factory.build(build_context=build_context)] batch_size = cls.__random__.randint(cls.__min_collection_length__, cls.__max_collection_length__) return factory.batch(size=batch_size, build_context=build_context) - if (origin := get_type_origin(unwrapped_annotation)) and issubclass(origin, Collection): + if (origin := get_type_origin(unwrapped_annotation)) and is_safe_subclass(origin, Collection): if cls.__randomize_collection_length__: collection_type = get_collection_type(unwrapped_annotation) if collection_type != dict: @@ -695,7 +701,8 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912 return handle_collection_type(field_meta, origin, cls) if is_union(field_meta.annotation) and field_meta.children: - return cls.get_field_value(cls.__random__.choice(field_meta.children)) + children = [child for child in field_meta.children if child.annotation not in build_context["seen_models"]] + return cls.get_field_value(cls.__random__.choice(children)) if is_any(unwrapped_annotation) or isinstance(unwrapped_annotation, TypeVar): return create_random_string(cls.__random__, min_length=1, max_length=10) @@ -888,11 +895,15 @@ def process_kwargs(cls, build_context: BuildContext | None = None, **kwargs: Any ) continue - result[field_meta.name] = cls.get_field_value( + field_result = cls.get_field_value( field_meta, field_build_parameters=field_build_parameters, build_context=build_context, ) + if field_result is Ignore: + continue + + result[field_meta.name] = field_result for field_name, post_generator in generate_post.items(): result[field_name] = post_generator.to_value(field_name, result) @@ -971,7 +982,7 @@ def batch(cls, size: int, build_context: BuildContext | None = None, **kwargs: A :returns: A list of instances of type T. """ - return [cls.build(build_context, **kwargs) for _ in range(size)] + return [cls.build(build_context=build_context, **kwargs) for _ in range(size)] @classmethod def coverage(cls, **kwargs: Any) -> abc.Iterator[T]: diff --git a/polyfactory/factories/pydantic_factory.py b/polyfactory/factories/pydantic_factory.py index 8998bfa6..8eba5d07 100644 --- a/polyfactory/factories/pydantic_factory.py +++ b/polyfactory/factories/pydantic_factory.py @@ -368,10 +368,10 @@ def get_constrained_field_value(cls, annotation: Any, field_meta: FieldMeta) -> return super().get_constrained_field_value(annotation, field_meta) @classmethod - def build( + def build( # type: ignore[override] cls, - build_context: BuildContext | None = None, factory_use_construct: bool = False, + build_context: BuildContext | None = None, **kwargs: Any, ) -> T: """Build an instance of the factory's __model__ diff --git a/polyfactory/utils/helpers.py b/polyfactory/utils/helpers.py index c28c0626..a6432b58 100644 --- a/polyfactory/utils/helpers.py +++ b/polyfactory/utils/helpers.py @@ -11,13 +11,7 @@ from typing_extensions import get_args, get_origin from polyfactory.constants import TYPE_MAPPING -from polyfactory.utils.predicates import ( - is_annotated, - is_new_type, - is_optional, - is_safe_subclass, - is_union, -) +from polyfactory.utils.predicates import is_annotated, is_new_type, is_optional, is_safe_subclass, is_union if TYPE_CHECKING: from random import Random @@ -71,15 +65,14 @@ def unwrap_annotation(annotation: Any, random: Random) -> Any: :returns: The unwrapped annotation. """ - while is_optional(annotation) or is_union(annotation) or is_new_type(annotation) or is_annotated(annotation): + while is_optional(annotation) or is_new_type(annotation) or is_annotated(annotation): if is_new_type(annotation): annotation = unwrap_new_type(annotation) elif is_optional(annotation): annotation = unwrap_optional(annotation) elif is_annotated(annotation): annotation = unwrap_annotated(annotation, random=random)[0] - else: - annotation = unwrap_union(annotation, random=random) + return annotation diff --git a/tests/test_recursive_models.py b/tests/test_recursive_models.py index 46edf0c9..6da76644 100644 --- a/tests/test_recursive_models.py +++ b/tests/test_recursive_models.py @@ -1,21 +1,69 @@ from __future__ import annotations -from dataclasses import dataclass -from typing import List, Optional +from dataclasses import dataclass, field +from typing import Any, List, Optional, Union + +import pytest +from pydantic import BaseModel, Field from polyfactory.factories.dataclass_factory import DataclassFactory +from polyfactory.factories.pydantic_factory import ModelFactory + + +class _Sentinel: + ... @dataclass class Node: - a: int - child: Optional[Node] # noqa: UP007 + value: int + union_child: Union[Node, int] # noqa: UP007 + list_child: List[Node] # noqa: UP006 + optional_child: Optional[Node] # noqa: UP007 + child: Node = field(default=_Sentinel) # type: ignore[assignment] + + def __post_init__(self) -> None: + # Emulate recursive models set by external init, e.g. ORM relationships + if self.child is _Sentinel: # type: ignore[comparison-overlap] + self.child = self def test_recusive_model() -> None: factory = DataclassFactory.create_factory(Node) - assert factory.build().child is None - assert factory.build(child={"child": None}).child.child is None # type: ignore[union-attr] + + result = factory.build() + assert result.child is result, "Default is not used" + assert isinstance(result.union_child, int) + assert result.optional_child is None + assert result.list_child == [] + + assert factory.build(child={"child": None}).child.child is None + + +class PydanticNode(BaseModel): + value: int + union_child: Union[PydanticNode, int] # noqa: UP007 + list_child: List[PydanticNode] # noqa: UP006 + optional_child: Union[PydanticNode, None] # noqa: UP007 + child: PydanticNode = Field(default=_Sentinel) # type: ignore[assignment] + + def model_post_init(self, context: Any) -> None: + # Emulate recursive models set by external init, e.g. ORM relationships + if self.child is _Sentinel: + self.child = self + + +@pytest.mark.parametrize("factory_use_construct", (True, False)) +def test_recursive_pydantic_models(factory_use_construct: bool) -> None: + factory = ModelFactory.create_factory(PydanticNode) + + result = factory.build(factory_use_construct) + assert result.child is result, "Default is not used" + assert isinstance(result.union_child, int) + assert result.optional_child is None + assert result.list_child == [] + + assert factory.build(child={"child": None}).child.child is None @dataclass @@ -24,15 +72,18 @@ class Author: books: List[Book] # noqa: UP006 +_DEFAULT_AUTHOR = Author(name="default", books=[]) + + @dataclass class Book: name: str - author: Author + author: Author = field(default_factory=lambda: _DEFAULT_AUTHOR) def test_recusive_list_model() -> None: factory = DataclassFactory.create_factory(Author) - assert factory.build().books[0].author is None + assert factory.build().books[0].author is _DEFAULT_AUTHOR assert factory.build(books=[]).books == [] book_factory = DataclassFactory.create_factory(Book) From a22a980029f594ce7171651d568c5d8db8fa9c15 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Wed, 3 Jan 2024 14:46:46 +0000 Subject: [PATCH 07/11] fix: adjust tests --- polyfactory/factories/base.py | 2 +- tests/test_dicts.py | 6 +++--- tests/test_recursive_models.py | 9 ++------- 3 files changed, 6 insertions(+), 11 deletions(-) diff --git a/polyfactory/factories/base.py b/polyfactory/factories/base.py index 4ea192dd..c8d62782 100644 --- a/polyfactory/factories/base.py +++ b/polyfactory/factories/base.py @@ -700,7 +700,7 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912 return handle_collection_type(field_meta, origin, cls) - if is_union(field_meta.annotation) and field_meta.children: + if is_union(unwrapped_annotation) and field_meta.children: children = [child for child in field_meta.children if child.annotation not in build_context["seen_models"]] return cls.get_field_value(cls.__random__.choice(children)) diff --git a/tests/test_dicts.py b/tests/test_dicts.py index 0eae3da0..d9ad95b2 100644 --- a/tests/test_dicts.py +++ b/tests/test_dicts.py @@ -36,10 +36,10 @@ class MyClass(BaseModel): class MyClassFactory(ModelFactory[MyClass]): __model__ = MyClass - MyClassFactory.seed_random(100) + MyClassFactory.seed_random(10) test_obj_1 = MyClassFactory.build() test_obj_2 = MyClassFactory.build() - assert isinstance(next(iter(test_obj_1.val.values())), str) - assert isinstance(next(iter(test_obj_2.val.values())), int) + assert isinstance(next(iter(test_obj_1.val.values())), int) + assert isinstance(next(iter(test_obj_2.val.values())), str) diff --git a/tests/test_recursive_models.py b/tests/test_recursive_models.py index 6da76644..95456ac4 100644 --- a/tests/test_recursive_models.py +++ b/tests/test_recursive_models.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Any, List, Optional, Union +from typing import List, Optional, Union import pytest from pydantic import BaseModel, Field @@ -47,18 +47,13 @@ class PydanticNode(BaseModel): optional_child: Union[PydanticNode, None] # noqa: UP007 child: PydanticNode = Field(default=_Sentinel) # type: ignore[assignment] - def model_post_init(self, context: Any) -> None: - # Emulate recursive models set by external init, e.g. ORM relationships - if self.child is _Sentinel: - self.child = self - @pytest.mark.parametrize("factory_use_construct", (True, False)) def test_recursive_pydantic_models(factory_use_construct: bool) -> None: factory = ModelFactory.create_factory(PydanticNode) result = factory.build(factory_use_construct) - assert result.child is result, "Default is not used" + assert result.child is _Sentinel, "Default is not used" assert isinstance(result.union_child, int) assert result.optional_child is None assert result.list_child == [] From 7f2d81bda23b7414b1b49283821db9b1efa0b07c Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Thu, 4 Jan 2024 18:22:32 +0000 Subject: [PATCH 08/11] fix: make interface backwards compatible. Favour Null over Ignore --- polyfactory/factories/base.py | 44 ++++++++++++----------- polyfactory/factories/pydantic_factory.py | 10 ++++-- tests/test_recursive_models.py | 2 -- 3 files changed, 30 insertions(+), 26 deletions(-) diff --git a/polyfactory/factories/base.py b/polyfactory/factories/base.py index c8d62782..a2884888 100644 --- a/polyfactory/factories/base.py +++ b/polyfactory/factories/base.py @@ -24,6 +24,8 @@ from pathlib import Path from random import Random +from polyfactory.field_meta import Null + try: from types import NoneType except ImportError: @@ -288,14 +290,14 @@ def _handle_factory_field( """ if is_safe_subclass(field_value, BaseFactory): if isinstance(field_build_parameters, Mapping): - return field_value.build(build_context=build_context, **field_build_parameters) + return field_value.build(_build_context=build_context, **field_build_parameters) if isinstance(field_build_parameters, Sequence): return [ - field_value.build(build_context=build_context, **parameter) for parameter in field_build_parameters + field_value.build(_build_context=build_context, **parameter) for parameter in field_build_parameters ] - return field_value.build(build_context=build_context) + return field_value.build(_build_context=build_context) if isinstance(field_value, Use): return field_value.to_value() @@ -655,10 +657,10 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912 if BaseFactory.is_factory_type(annotation=unwrapped_annotation): if not field_build_parameters and unwrapped_annotation in build_context["seen_models"]: - return None if is_optional(field_meta.annotation) else Ignore + return None if is_optional(field_meta.annotation) else Null return cls._get_or_create_factory(model=unwrapped_annotation).build( - build_context, + _build_context=build_context, **(field_build_parameters if isinstance(field_build_parameters, Mapping) else {}), ) @@ -666,7 +668,7 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912 factory = cls._get_or_create_factory(model=field_meta.type_args[0]) if isinstance(field_build_parameters, Sequence): return [ - factory.build(build_context=build_context, **field_parameters) + factory.build(_build_context=build_context, **field_parameters) for field_parameters in field_build_parameters ] @@ -674,10 +676,10 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912 return [] if not cls.__randomize_collection_length__: - return [factory.build(build_context=build_context)] + return [factory.build(_build_context=build_context)] batch_size = cls.__random__.randint(cls.__min_collection_length__, cls.__max_collection_length__) - return factory.batch(size=batch_size, build_context=build_context) + return factory.batch(size=batch_size, _build_context=build_context) if (origin := get_type_origin(unwrapped_annotation)) and is_safe_subclass(origin, Collection): if cls.__randomize_collection_length__: @@ -858,7 +860,7 @@ def _check_declared_fields_exist_in_model(cls) -> None: raise ConfigurationException(error_message) @classmethod - def process_kwargs(cls, build_context: BuildContext | None = None, **kwargs: Any) -> dict[str, Any]: + def process_kwargs(cls, *, _build_context: BuildContext | None = None, **kwargs: Any) -> dict[str, Any]: """Process the given kwargs and generate values for the factory's model. :param kwargs: Any build kwargs. @@ -866,8 +868,8 @@ def process_kwargs(cls, build_context: BuildContext | None = None, **kwargs: Any :returns: A dictionary of build results. """ - build_context = _get_build_context(build_context) - build_context["seen_models"].add(cls.__model__) + _build_context = _get_build_context(_build_context) + _build_context["seen_models"].add(cls.__model__) result: dict[str, Any] = {**kwargs} generate_post: dict[str, PostGenerated] = {} @@ -891,16 +893,16 @@ def process_kwargs(cls, build_context: BuildContext | None = None, **kwargs: Any result[field_meta.name] = cls._handle_factory_field( field_value=field_value, field_build_parameters=field_build_parameters, - build_context=build_context, + build_context=_build_context, ) continue field_result = cls.get_field_value( field_meta, field_build_parameters=field_build_parameters, - build_context=build_context, + build_context=_build_context, ) - if field_result is Ignore: + if field_result is Null: continue result[field_meta.name] = field_result @@ -913,7 +915,7 @@ def process_kwargs(cls, build_context: BuildContext | None = None, **kwargs: Any @classmethod def process_kwargs_coverage( cls, - build_context: BuildContext | None = None, + _build_context: BuildContext | None = None, **kwargs: Any, ) -> abc.Iterable[dict[str, Any]]: """Process the given kwargs and generate values for the factory's model. @@ -923,8 +925,8 @@ def process_kwargs_coverage( :returns: A dictionary of build results. """ - build_context = _get_build_context(build_context) - build_context["seen_models"].add(cls.__model__) + _build_context = _get_build_context(_build_context) + _build_context["seen_models"].add(cls.__model__) result: dict[str, Any] = {**kwargs} generate_post: dict[str, PostGenerated] = {} @@ -962,7 +964,7 @@ def process_kwargs_coverage( yield resolved @classmethod - def build(cls, build_context: BuildContext | None = None, **kwargs: Any) -> T: + def build(cls, *, _build_context: BuildContext | None = None, **kwargs: Any) -> T: """Build an instance of the factory's __model__ :param kwargs: Any kwargs. If field names are set in kwargs, their values will be used. @@ -970,10 +972,10 @@ def build(cls, build_context: BuildContext | None = None, **kwargs: Any) -> T: :returns: An instance of type T. """ - return cast("T", cls.__model__(**cls.process_kwargs(build_context, **kwargs))) + return cast("T", cls.__model__(**cls.process_kwargs(_build_context=_build_context, **kwargs))) @classmethod - def batch(cls, size: int, build_context: BuildContext | None = None, **kwargs: Any) -> list[T]: + def batch(cls, size: int, _build_context: BuildContext | None = None, **kwargs: Any) -> list[T]: """Build a batch of size n of the factory's Meta.model. :param size: Size of the batch. @@ -982,7 +984,7 @@ def batch(cls, size: int, build_context: BuildContext | None = None, **kwargs: A :returns: A list of instances of type T. """ - return [cls.build(build_context=build_context, **kwargs) for _ in range(size)] + return [cls.build(_build_context=_build_context, **kwargs) for _ in range(size)] @classmethod def coverage(cls, **kwargs: Any) -> abc.Iterator[T]: diff --git a/polyfactory/factories/pydantic_factory.py b/polyfactory/factories/pydantic_factory.py index 8eba5d07..f0683834 100644 --- a/polyfactory/factories/pydantic_factory.py +++ b/polyfactory/factories/pydantic_factory.py @@ -368,10 +368,11 @@ def get_constrained_field_value(cls, annotation: Any, field_meta: FieldMeta) -> return super().get_constrained_field_value(annotation, field_meta) @classmethod - def build( # type: ignore[override] + def build( cls, factory_use_construct: bool = False, - build_context: BuildContext | None = None, + *, + _build_context: BuildContext | None = None, **kwargs: Any, ) -> T: """Build an instance of the factory's __model__ @@ -383,7 +384,10 @@ def build( # type: ignore[override] :returns: An instance of type T. """ - processed_kwargs = cls.process_kwargs(build_context, **kwargs) + processed_kwargs = cls.process_kwargs( + _build_context=_build_context, + **kwargs, + ) if factory_use_construct: return ( diff --git a/tests/test_recursive_models.py b/tests/test_recursive_models.py index 95456ac4..9f8c4ad1 100644 --- a/tests/test_recursive_models.py +++ b/tests/test_recursive_models.py @@ -58,8 +58,6 @@ def test_recursive_pydantic_models(factory_use_construct: bool) -> None: assert result.optional_child is None assert result.list_child == [] - assert factory.build(child={"child": None}).child.child is None - @dataclass class Author: From 45c3e3edffc63ef0a3db5af29a2c9086f1fa96b0 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Tue, 9 Jan 2024 17:14:17 +0000 Subject: [PATCH 09/11] fix: PR comments --- polyfactory/factories/base.py | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/polyfactory/factories/base.py b/polyfactory/factories/base.py index a2884888..8dcdfe95 100644 --- a/polyfactory/factories/base.py +++ b/polyfactory/factories/base.py @@ -308,7 +308,12 @@ def _handle_factory_field( return field_value() if callable(field_value) else field_value @classmethod - def _handle_factory_field_coverage(cls, field_value: Any, field_build_parameters: Any | None = None) -> Any: + def _handle_factory_field_coverage( + cls, + field_value: Any, + field_build_parameters: Any | None = None, + build_context: BuildContext | None = None, + ) -> Any: """Handle a value defined on the factory class itself. :param field_value: A value defined as an attribute on the factory class. @@ -318,10 +323,13 @@ def _handle_factory_field_coverage(cls, field_value: Any, field_build_parameters """ if is_safe_subclass(field_value, BaseFactory): if isinstance(field_build_parameters, Mapping): - return CoverageContainer(field_value.coverage(**field_build_parameters)) + return CoverageContainer(field_value.coverage(_build_context=build_context, **field_build_parameters)) if isinstance(field_build_parameters, Sequence): - return [CoverageContainer(field_value.coverage(**parameter)) for parameter in field_build_parameters] + return [ + CoverageContainer(field_value.coverage(_build_context=build_context, **parameter)) + for parameter in field_build_parameters + ] return CoverageContainer(field_value.coverage()) @@ -728,11 +736,13 @@ def get_field_value_coverage( # noqa: C901 cls, field_meta: FieldMeta, field_build_parameters: Any | None = None, + build_context: BuildContext | None = None, ) -> Iterable[Any]: """Return a field value on the subclass if existing, otherwise returns a mock value. :param field_meta: FieldMeta instance. :param field_build_parameters: Any build parameters passed to the factory as kwarg values. + :param build_context: BuildContext data for current build. :returns: An iterable of values. @@ -760,6 +770,7 @@ def get_field_value_coverage( # noqa: C901 elif BaseFactory.is_factory_type(annotation=unwrapped_annotation): yield CoverageContainer( cls._get_or_create_factory(model=unwrapped_annotation).coverage( + _build_context=build_context, **(field_build_parameters if isinstance(field_build_parameters, Mapping) else {}), ), ) @@ -915,12 +926,14 @@ def process_kwargs(cls, *, _build_context: BuildContext | None = None, **kwargs: @classmethod def process_kwargs_coverage( cls, + *, _build_context: BuildContext | None = None, **kwargs: Any, ) -> abc.Iterable[dict[str, Any]]: """Process the given kwargs and generate values for the factory's model. :param kwargs: Any build kwargs. + :param build_context: BuildContext data for current build. :returns: A dictionary of build results. @@ -951,11 +964,16 @@ def process_kwargs_coverage( result[field_meta.name] = cls._handle_factory_field_coverage( field_value=field_value, field_build_parameters=field_build_parameters, + build_context=_build_context, ) continue result[field_meta.name] = CoverageContainer( - cls.get_field_value_coverage(field_meta, field_build_parameters=field_build_parameters), + cls.get_field_value_coverage( + field_meta, + field_build_parameters=field_build_parameters, + build_context=_build_context, + ), ) for resolved in resolve_kwargs_coverage(result): @@ -987,7 +1005,7 @@ def batch(cls, size: int, _build_context: BuildContext | None = None, **kwargs: return [cls.build(_build_context=_build_context, **kwargs) for _ in range(size)] @classmethod - def coverage(cls, **kwargs: Any) -> abc.Iterator[T]: + def coverage(cls, _build_context: BuildContext | None = None, **kwargs: Any) -> abc.Iterator[T]: """Build a batch of the factory's Meta.model will full coverage of the sub-types of the model. :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. @@ -995,7 +1013,7 @@ def coverage(cls, **kwargs: Any) -> abc.Iterator[T]: :returns: A iterator of instances of type T. """ - for data in cls.process_kwargs_coverage(**kwargs): + for data in cls.process_kwargs_coverage(_build_context=_build_context, **kwargs): instance = cls.__model__(**data) yield cast("T", instance) From b444845503527f88a21307ae043c69da0b3dba42 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Mon, 15 Jan 2024 11:22:06 +0000 Subject: [PATCH 10/11] fix: PR updates --- .../sqlalchemy_factory/conftest.py | 11 ++++++++ .../sqlalchemy_factory/test_example_4.py | 2 -- polyfactory/factories/base.py | 25 ++++++++----------- polyfactory/factories/pydantic_factory.py | 9 ++----- ...t_passing_build_args_to_child_factories.py | 4 +-- 5 files changed, 25 insertions(+), 26 deletions(-) create mode 100644 docs/examples/library_factories/sqlalchemy_factory/conftest.py diff --git a/docs/examples/library_factories/sqlalchemy_factory/conftest.py b/docs/examples/library_factories/sqlalchemy_factory/conftest.py new file mode 100644 index 00000000..2227e64c --- /dev/null +++ b/docs/examples/library_factories/sqlalchemy_factory/conftest.py @@ -0,0 +1,11 @@ +from collections.abc import Iterable + +import pytest + +from docs.examples.library_factories.sqlalchemy_factory.test_example_4 import BaseFactory + + +@pytest.fixture(scope="module") +def _remove_default_factories() -> Iterable[None]: + yield + BaseFactory._base_factories.remove(BaseFactory) # noqa: SLF001 diff --git a/docs/examples/library_factories/sqlalchemy_factory/test_example_4.py b/docs/examples/library_factories/sqlalchemy_factory/test_example_4.py index b0af4135..0d616a38 100644 --- a/docs/examples/library_factories/sqlalchemy_factory/test_example_4.py +++ b/docs/examples/library_factories/sqlalchemy_factory/test_example_4.py @@ -56,5 +56,3 @@ def test_custom_sqla_factory() -> None: book = BaseFactory.create_factory(Book).create_sync() assert book.id is not None assert book.author.books == [book] - - BaseFactory._base_factories.remove(BaseFactory) # noqa: SLF001 diff --git a/polyfactory/factories/base.py b/polyfactory/factories/base.py index 7ef1b0c5..d49692f3 100644 --- a/polyfactory/factories/base.py +++ b/polyfactory/factories/base.py @@ -889,7 +889,7 @@ def _check_declared_fields_exist_in_model(cls) -> None: raise ConfigurationException(error_message) @classmethod - def process_kwargs(cls, *, _build_context: BuildContext | None = None, **kwargs: Any) -> dict[str, Any]: + def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]: """Process the given kwargs and generate values for the factory's model. :param kwargs: Any build kwargs. @@ -897,7 +897,7 @@ def process_kwargs(cls, *, _build_context: BuildContext | None = None, **kwargs: :returns: A dictionary of build results. """ - _build_context = _get_build_context(_build_context) + _build_context = _get_build_context(kwargs.pop("_build_context", None)) _build_context["seen_models"].add(cls.__model__) result: dict[str, Any] = {**kwargs} @@ -942,12 +942,7 @@ def process_kwargs(cls, *, _build_context: BuildContext | None = None, **kwargs: return result @classmethod - def process_kwargs_coverage( - cls, - *, - _build_context: BuildContext | None = None, - **kwargs: Any, - ) -> abc.Iterable[dict[str, Any]]: + def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]: """Process the given kwargs and generate values for the factory's model. :param kwargs: Any build kwargs. @@ -956,7 +951,7 @@ def process_kwargs_coverage( :returns: A dictionary of build results. """ - _build_context = _get_build_context(_build_context) + _build_context = _get_build_context(kwargs.pop("_build_context", None)) _build_context["seen_models"].add(cls.__model__) result: dict[str, Any] = {**kwargs} @@ -1000,7 +995,7 @@ def process_kwargs_coverage( yield resolved @classmethod - def build(cls, *, _build_context: BuildContext | None = None, **kwargs: Any) -> T: + def build(cls, **kwargs: Any) -> T: """Build an instance of the factory's __model__ :param kwargs: Any kwargs. If field names are set in kwargs, their values will be used. @@ -1008,10 +1003,10 @@ def build(cls, *, _build_context: BuildContext | None = None, **kwargs: Any) -> :returns: An instance of type T. """ - return cast("T", cls.__model__(**cls.process_kwargs(_build_context=_build_context, **kwargs))) + return cast("T", cls.__model__(**cls.process_kwargs(**kwargs))) @classmethod - def batch(cls, size: int, _build_context: BuildContext | None = None, **kwargs: Any) -> list[T]: + def batch(cls, size: int, **kwargs: Any) -> list[T]: """Build a batch of size n of the factory's Meta.model. :param size: Size of the batch. @@ -1020,10 +1015,10 @@ def batch(cls, size: int, _build_context: BuildContext | None = None, **kwargs: :returns: A list of instances of type T. """ - return [cls.build(_build_context=_build_context, **kwargs) for _ in range(size)] + return [cls.build(**kwargs) for _ in range(size)] @classmethod - def coverage(cls, _build_context: BuildContext | None = None, **kwargs: Any) -> abc.Iterator[T]: + def coverage(cls, **kwargs: Any) -> abc.Iterator[T]: """Build a batch of the factory's Meta.model will full coverage of the sub-types of the model. :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. @@ -1031,7 +1026,7 @@ def coverage(cls, _build_context: BuildContext | None = None, **kwargs: Any) -> :returns: A iterator of instances of type T. """ - for data in cls.process_kwargs_coverage(_build_context=_build_context, **kwargs): + for data in cls.process_kwargs_coverage(**kwargs): instance = cls.__model__(**data) yield cast("T", instance) diff --git a/polyfactory/factories/pydantic_factory.py b/polyfactory/factories/pydantic_factory.py index f0683834..16218dc4 100644 --- a/polyfactory/factories/pydantic_factory.py +++ b/polyfactory/factories/pydantic_factory.py @@ -13,7 +13,7 @@ from polyfactory.collection_extender import CollectionExtender from polyfactory.constants import DEFAULT_RANDOM from polyfactory.exceptions import MissingDependencyException -from polyfactory.factories.base import BaseFactory, BuildContext +from polyfactory.factories.base import BaseFactory from polyfactory.field_meta import Constraints, FieldMeta, Null from polyfactory.utils.deprecation import check_for_deprecated_parameters from polyfactory.utils.helpers import unwrap_new_type, unwrap_optional @@ -371,8 +371,6 @@ def get_constrained_field_value(cls, annotation: Any, field_meta: FieldMeta) -> def build( cls, factory_use_construct: bool = False, - *, - _build_context: BuildContext | None = None, **kwargs: Any, ) -> T: """Build an instance of the factory's __model__ @@ -384,10 +382,7 @@ def build( :returns: An instance of type T. """ - processed_kwargs = cls.process_kwargs( - _build_context=_build_context, - **kwargs, - ) + processed_kwargs = cls.process_kwargs(**kwargs) if factory_use_construct: return ( diff --git a/tests/test_passing_build_args_to_child_factories.py b/tests/test_passing_build_args_to_child_factories.py index eab2c936..1ba4ff2a 100644 --- a/tests/test_passing_build_args_to_child_factories.py +++ b/tests/test_passing_build_args_to_child_factories.py @@ -64,7 +64,7 @@ def test_factory_child_model_list() -> None: }, } - person = PersonFactory.build(factory_use_construct=False, **data) # type: ignore[arg-type] + person = PersonFactory.build(factory_use_construct=False, **data) assert person.name == "Jean" assert len(person.pets) == 2 @@ -174,7 +174,7 @@ class D(BaseModel): class DFactory(ModelFactory): __model__ = D - build_result = DFactory.build(factory_use_construct=False, **{"c": {"b": {"a": {"name": "test"}}}}) # type: ignore[arg-type] + build_result = DFactory.build(factory_use_construct=False, **{"c": {"b": {"a": {"name": "test"}}}}) assert build_result assert build_result.c.b.a.name == "test" From b6988bd9ae4ed7a2c4af50947466934e4b967b9a Mon Sep 17 00:00:00 2001 From: Sourcery AI <> Date: Tue, 16 Jan 2024 07:05:11 +0000 Subject: [PATCH 11/11] 'Refactored by Sourcery' --- polyfactory/factories/base.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/polyfactory/factories/base.py b/polyfactory/factories/base.py index 9ab29c0e..41a442d5 100644 --- a/polyfactory/factories/base.py +++ b/polyfactory/factories/base.py @@ -245,10 +245,7 @@ class Foo(ModelFactory[MyModel]): # <<< MyModel generic_args: Sequence[type[T]] = [ arg for factory_base in factory_bases for arg in get_args(factory_base) if not isinstance(arg, TypeVar) ] - if len(generic_args) != 1: - return None - - return generic_args[0] + return None if len(generic_args) != 1 else generic_args[0] @classmethod def _get_sync_persistence(cls) -> SyncPersistenceProtocol[T]: