Skip to content

Commit c8d6f08

Browse files
committed
feat: add post_create and post_generate hooks
1 parent 396b555 commit c8d6f08

File tree

2 files changed

+29
-3
lines changed

2 files changed

+29
-3
lines changed

polyfactory/factories/base.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1067,6 +1067,24 @@ def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]:
10671067
for field_name, post_generator in generate_post.items():
10681068
result[field_name] = post_generator.to_value(field_name, result)
10691069

1070+
return cls.post_generate(result)
1071+
1072+
@classmethod
1073+
def post_create(cls, model: T) -> None:
1074+
"""Post-create hook. Helpful for building additional database associations or running logic which requires the
1075+
fully-created model.
1076+
1077+
:param model: The created model instance.
1078+
"""
1079+
pass
1080+
1081+
@classmethod
1082+
def post_generate(cls, result: dict[str, Any]) -> dict[str, Any]:
1083+
"""Post-generate hook. Helpful for mutating or adding additional fields right before model creation.
1084+
1085+
:param result: The kwargs that will be passed to the model.
1086+
:returns: The (optionally) mutated kwargs.
1087+
"""
10701088
return result
10711089

10721090
@classmethod
@@ -1127,7 +1145,11 @@ def build(cls, *_: Any, **kwargs: Any) -> T:
11271145
:returns: An instance of type T.
11281146
11291147
"""
1130-
return cast("T", cls.__model__(**cls.process_kwargs(**kwargs)))
1148+
created_model = cast("T", cls.__model__(**cls.process_kwargs(**kwargs)))
1149+
1150+
cls.post_create(created_model)
1151+
1152+
return created_model
11311153

11321154
@classmethod
11331155
def batch(cls, size: int, **kwargs: Any) -> list[T]:

polyfactory/factories/pydantic_factory.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def from_field_info(
147147
min_collection_length: int | None = None,
148148
max_collection_length: int | None = None,
149149
) -> PydanticFieldMeta:
150-
"""Create an instance from a pydantic field info.
150+
"""Create an instance from a pydantic field info. Used by `get_model_fields` to generate field list for a model.
151151
152152
:param field_name: The name of the field.
153153
:param field_info: A pydantic FieldInfo instance.
@@ -473,7 +473,11 @@ def build(
473473

474474
processed_kwargs = cls.process_kwargs(**kwargs)
475475

476-
return cls._create_model(kwargs["_build_context"], **processed_kwargs)
476+
created_model = cls._create_model(kwargs["_build_context"], **processed_kwargs)
477+
478+
cls.post_create(created_model)
479+
480+
return created_model
477481

478482
@classmethod
479483
def _get_build_context(cls, build_context: BaseBuildContext | PydanticBuildContext | None) -> PydanticBuildContext:

0 commit comments

Comments
 (0)