diff --git a/polyfactory/factories/base.py b/polyfactory/factories/base.py index 09058037..20ad9616 100644 --- a/polyfactory/factories/base.py +++ b/polyfactory/factories/base.py @@ -1071,6 +1071,25 @@ def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]: for field_name, post_generator in generate_post.items(): result[field_name] = post_generator.to_value(field_name, result) + return cls.post_generate(result) + + @classmethod + def post_build(cls, model: T) -> T: + """Post-create hook. Helpful for building additional database associations or running logic which requires the + fully-created model. + + :param model: The created model instance. + :returns: The (optionally) mutated model. + """ + return model + + @classmethod + def post_generate(cls, result: dict[str, Any]) -> dict[str, Any]: + """Post-generate hook. Helpful for mutating or adding additional fields right before model creation. + + :param result: The kwargs that will be passed to the model. + :returns: The (optionally) mutated kwargs. + """ return result @classmethod @@ -1131,7 +1150,11 @@ def build(cls, *_: Any, **kwargs: Any) -> T: :returns: An instance of type T. """ - return cast("T", cls.__model__(**cls.process_kwargs(**kwargs))) + created_model = cast("T", cls.__model__(**cls.process_kwargs(**kwargs))) + + cls.post_build(created_model) + + return created_model @classmethod def batch(cls, size: int, **kwargs: Any) -> list[T]: @@ -1156,6 +1179,7 @@ def coverage(cls, **kwargs: Any) -> abc.Iterator[T]: """ for data in cls.process_kwargs_coverage(**kwargs): instance = cls.__model__(**data) + cls.post_build(instance) yield cast("T", instance) @classmethod diff --git a/polyfactory/factories/pydantic_factory.py b/polyfactory/factories/pydantic_factory.py index e0b9a757..6871eaac 100644 --- a/polyfactory/factories/pydantic_factory.py +++ b/polyfactory/factories/pydantic_factory.py @@ -147,7 +147,7 @@ def from_field_info( min_collection_length: int | None = None, max_collection_length: int | None = None, ) -> PydanticFieldMeta: - """Create an instance from a pydantic field info. + """Create an instance from a pydantic field info. Used by `get_model_fields` to generate field list for a model. :param field_name: The name of the field. :param field_info: A pydantic FieldInfo instance. @@ -517,7 +517,11 @@ def build( processed_kwargs = cls.process_kwargs(**kwargs) - return cls._create_model(kwargs["_build_context"], **processed_kwargs) + created_model = cls._create_model(kwargs["_build_context"], **processed_kwargs) + + cls.post_build(created_model) + + return created_model @classmethod def _get_build_context(cls, build_context: BaseBuildContext | PydanticBuildContext | None) -> PydanticBuildContext: @@ -568,7 +572,11 @@ def coverage(cls, factory_use_construct: bool = False, **kwargs: Any) -> abc.Ite ) for data in cls.process_kwargs_coverage(**kwargs): - yield cls._create_model(_build_context=kwargs["_build_context"], **data) + created_model = cls._create_model(_build_context=kwargs["_build_context"], **data) + + cls.post_build(created_model) + + yield created_model @classmethod def is_custom_root_field(cls, field_meta: FieldMeta) -> bool: diff --git a/tests/test_factory_fields.py b/tests/test_factory_fields.py index e07aa9f6..840be428 100644 --- a/tests/test_factory_fields.py +++ b/tests/test_factory_fields.py @@ -166,6 +166,25 @@ def caption(cls, is_long: bool) -> str: assert result.caption == "just this" +def test_post_build_classmethod() -> None: + @dataclass + class Model: + i: int + j: int + + class Factory(DataclassFactory[Model]): + __model__ = Model + + @classmethod + def post_build(cls, model: Model) -> Model: + model.i = model.j + 10 + return model + + result = Factory.build() + + assert result.i == result.j + 10 + + @pytest.mark.parametrize( "factory_field", [ diff --git a/tests/test_pydantic_factory.py b/tests/test_pydantic_factory.py index 5612cf55..cc6aef30 100644 --- a/tests/test_pydantic_factory.py +++ b/tests/test_pydantic_factory.py @@ -1100,3 +1100,20 @@ class PaymentFactory(ModelFactory[Payment]): instance = PaymentFactory.build(currency="DKK") assert instance.currency == "DKK" + + +def test_post_build_classmethod() -> None: + class Model(BaseModel): + i: int + j: int + + class Factory(ModelFactory[Model]): + __model__ = Model + + @classmethod + def post_build(cls, model: Model) -> Model: + model.i = model.j + 10 + return model + + result = Factory.build() + assert result.i == result.j + 10 diff --git a/tests/test_type_coverage_generation.py b/tests/test_type_coverage_generation.py index d3861ff8..e77c4ffb 100644 --- a/tests/test_type_coverage_generation.py +++ b/tests/test_type_coverage_generation.py @@ -202,6 +202,26 @@ def i(cls, j: int) -> int: assert results[0].i == results[0].j + 10 +def test_coverage_post_build() -> None: + @dataclass + class Model: + i: int + j: int + + class Factory(DataclassFactory[Model]): + __model__ = Model + + @classmethod + def post_build(cls, model: Model) -> Model: + model.i = model.j + 10 + return model + + results = list(Factory.coverage()) + assert len(results) == 1 + + assert results[0].i == results[0].j + 10 + + class CustomInt: def __init__(self, value: int) -> None: self.value = value