Skip to content

Commit ddb77ff

Browse files
committed
fix: introduce pydantic v1/v2 code to hanble v1 dataclasses
1 parent e9daae0 commit ddb77ff

File tree

1 file changed

+41
-21
lines changed

1 file changed

+41
-21
lines changed

polyfactory/factories/pydantic_factory.py

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,15 @@
6161
# is installed.
6262
from pydantic import PyObject
6363

64-
# prevent unbound variable warnings
64+
# Prevent unbound variable warnings
6565
BaseModelV2 = BaseModelV1
6666
UndefinedV2 = Undefined
67+
68+
if TYPE_CHECKING:
69+
from pydantic.dataclasses import Dataclass as PydanticDataclassV1 # pyright: ignore[reportPrivateImportUsage]
70+
71+
# Prevent unbound variable warnings
72+
PydanticDataclassV2 = PydanticDataclassV1
6773
except ImportError:
6874
# pydantic v2
6975

@@ -92,6 +98,8 @@
9298
from pydantic.v1.color import Color # type: ignore[assignment]
9399
from pydantic.v1.fields import DeferredType, ModelField, Undefined
94100

101+
if TYPE_CHECKING:
102+
from pydantic.dataclasses import PydanticDataclass as PydanticDataclassV2 # pyright: ignore[reportPrivateImportUsage]
95103

96104
if TYPE_CHECKING:
97105
from collections import abc
@@ -100,7 +108,6 @@
100108

101109
from typing_extensions import NotRequired, TypeGuard
102110

103-
from pydantic.dataclasses import PydanticDataclass # pyright: ignore[reportPrivateImportUsage]
104111

105112
ModelT = TypeVar("ModelT", bound="BaseModelV1 | BaseModelV2") # pyright: ignore[reportInvalidTypeForm]
106113
T = TypeVar("T")
@@ -627,8 +634,11 @@ def _is_pydantic_v2_model(model: Any) -> TypeGuard[BaseModelV2]: # pyright: ign
627634
return not _IS_PYDANTIC_V1 and is_safe_subclass(model, BaseModelV2)
628635

629636

630-
def is_pydantic_dataclass(cls: type[Any]) -> TypeGuard[PydanticDataclass]:
631-
# This method is available in the `pydantic.dataclasses` module for python >= 3.9
637+
def _is_pydantic_v1_dataclass(cls: type[Any]) -> TypeGuard[PydanticDataclassV1]:
638+
return is_dataclass(cls) and "__pydantic_model__" in cls.__dict__
639+
640+
641+
def _is_pydantic_v2_dataclass(cls: type[Any]) -> TypeGuard[PydanticDataclassV2]:
632642
return is_dataclass(cls) and "__pydantic_validator__" in cls.__dict__
633643

634644

@@ -639,27 +649,37 @@ class PydanticDataclassFactory(ModelFactory[T]): # type: ignore[type-var]
639649

640650
@classmethod
641651
def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]:
642-
return is_pydantic_dataclass(value)
652+
return _is_pydantic_v1_dataclass(value) or _is_pydantic_v2_dataclass(value)
643653

644654
@classmethod
645655
def get_model_fields(cls) -> list[FieldMeta]:
646-
if not is_pydantic_dataclass(cls.__model__):
656+
if _is_pydantic_v1_dataclass(cls.__model__):
657+
pydantic_model = cls.__model__.__pydantic_model__
658+
cls._fields_metadata = [
659+
PydanticFieldMeta.from_model_field(
660+
field,
661+
use_alias=not pydantic_model.__config__.allow_population_by_field_name, # type: ignore[attr-defined]
662+
random=cls.__random__,
663+
)
664+
for field in pydantic_model.__fields__.values()
665+
]
666+
elif _is_pydantic_v2_dataclass(cls.__model__):
667+
pydantic_fields = cls.__model__.__pydantic_fields__
668+
pydantic_config = cls.__model__.__pydantic_config__
669+
cls._fields_metadata = [
670+
PydanticFieldMeta.from_field_info(
671+
field_info=field_info,
672+
field_name=field_name,
673+
random=cls.__random__,
674+
use_alias=not pydantic_config.get(
675+
"populate_by_name",
676+
False,
677+
),
678+
)
679+
for field_name, field_info in pydantic_fields.items()
680+
]
681+
else:
647682
# This should be unreachable
648683
return []
649684

650-
pydantic_fields = cls.__model__.__pydantic_fields__
651-
pydantic_config = cls.__model__.__pydantic_config__
652-
cls._fields_metadata = [
653-
PydanticFieldMeta.from_field_info(
654-
field_info=field_info,
655-
field_name=field_name,
656-
random=cls.__random__,
657-
use_alias=not pydantic_config.get(
658-
"populate_by_name",
659-
False,
660-
),
661-
)
662-
for field_name, field_info in pydantic_fields.items()
663-
]
664-
665685
return cls._fields_metadata

0 commit comments

Comments
 (0)