Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support model inheritance #346

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions docs/modeling.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@ option in the `Config` class.
Now, when `CapitalCity` instances will be persisted to the database, they will
belong in the `city` collection instead of `capital_city`.

!!! warning
Models and Embedded models inheritance is not supported yet.

### Indexes

#### Index definition
Expand Down
5 changes: 3 additions & 2 deletions odmantic/bson.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import decimal
import re
from datetime import datetime, timedelta
from typing import Any, Dict, Pattern
from typing import Any, Dict, Pattern, Type

import bson
import bson.binary
Expand All @@ -10,6 +10,7 @@
import bson.regex
from pydantic.datetime_parse import parse_datetime
from pydantic.main import BaseModel
from pydantic.typing import AnyCallable
from pydantic.validators import (
bytes_validator,
decimal_validator,
Expand Down Expand Up @@ -179,7 +180,7 @@ def __bson__(cls, v: Any) -> bson.decimal128.Decimal128:
return bson.decimal128.Decimal128(v)


BSON_TYPES_ENCODERS = {
BSON_TYPES_ENCODERS: Dict[Type[Any], AnyCallable] = {
bson.ObjectId: str,
bson.decimal128.Decimal128: lambda x: x.to_decimal(), # Convert to regular decimal
bson.regex.Regex: lambda x: x.pattern, # TODO: document no serialization of flags
Expand Down
44 changes: 15 additions & 29 deletions odmantic/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import pymongo
from pydantic import Extra
from pydantic.main import BaseConfig
from pydantic.typing import AnyCallable

from odmantic.bson import BSON_TYPES_ENCODERS
Expand Down Expand Up @@ -37,7 +36,7 @@ def indexes() -> Iterable[Union[ODMIndex.Index, pymongo.IndexModel]]:

# Inherited from pydantic
title: Optional[str] = None
json_encoders: Dict[Type[Any], AnyCallable] = {}
json_encoders: Dict[Type[Any], AnyCallable] = BSON_TYPES_ENCODERS
schema_extra: Union[Dict[str, Any], "SchemaExtraCallable"] = {}
anystr_strip_whitespace: bool = False
json_loads: Callable[[str], Any] = json.loads
Expand All @@ -49,36 +48,23 @@ def indexes() -> Iterable[Union[ODMIndex.Index, pymongo.IndexModel]]:
ALLOWED_CONFIG_OPTIONS = {name for name in dir(BaseODMConfig) if not is_dunder(name)}


class EnforcedPydanticConfig:
"""Configuration options enforced to work with Models"""
def combine_configs(*configs: Type[Any], **namespace: Any) -> Type[Any]:
# remove redundant bases
bases = list(configs)
while len(bases) > 1 and issubclass(bases[-2], bases[-1]):
del bases[-1]

json_encoders: Dict[Type[Any], AnyCallable] = {}
for config in reversed(bases):
json_encoders.update(getattr(config, "json_encoders", {}))
json_encoders.update(namespace.get("json_encoders", {}))
namespace["json_encoders"] = json_encoders

validate_all = True
validate_assignment = True
return type("Config", tuple(bases), namespace)


def validate_config(
cls_config: Type[BaseODMConfig], cls_name: str
) -> Type[BaseODMConfig]:
def validate_config(config: Type[BaseODMConfig], cls_name: str) -> None:
"""Validate and build the model configuration"""
for name in dir(cls_config):
for name in dir(config):
if not is_dunder(name) and name not in ALLOWED_CONFIG_OPTIONS:
raise ValueError(f"'{cls_name}': 'Config.{name}' is not supported")

if cls_config is BaseODMConfig:
bases = (EnforcedPydanticConfig, BaseODMConfig, BaseConfig)
else:
bases = (
EnforcedPydanticConfig,
cls_config,
BaseODMConfig,
BaseConfig,
) # type:ignore

# Merge json_encoders to preserve bson type encoders
namespace = {
"json_encoders": {
**BSON_TYPES_ENCODERS,
**getattr(cls_config, "json_encoders", {}),
}
}
return type("Config", bases, namespace)
145 changes: 71 additions & 74 deletions odmantic/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,15 @@
from pydantic.main import BaseModel
from pydantic.tools import parse_obj_as
from pydantic.typing import is_classvar, resolve_annotations
from pydantic.utils import smart_deepcopy

from odmantic.bson import (
_BSON_SUBSTITUTED_FIELDS,
BaseBSONModel,
ObjectId,
_decimalDecimal,
)
from odmantic.config import BaseODMConfig, validate_config
from odmantic.config import BaseODMConfig, combine_configs, validate_config
from odmantic.exceptions import (
DocumentParsingError,
ErrorList,
Expand All @@ -65,7 +66,6 @@
from odmantic.index import Index, ODMBaseIndex, ODMSingleFieldIndex
from odmantic.reference import ODMReferenceInfo
from odmantic.typing import (
GenericAlias,
Literal,
dataclass_transform,
get_args,
Expand Down Expand Up @@ -180,26 +180,42 @@ def validate_type(type_: Type) -> Type:
if subst_type is not None:
return subst_type

type_origin: Optional[Type] = get_origin(type_)
if type_origin is not None and type_origin is not Literal:
type_args: Tuple[Type, ...] = get_args(type_)
new_arg_types = tuple(validate_type(subtype) for subtype in type_args)
type_ = GenericAlias(type_origin, new_arg_types)
if get_origin(type_) not in (None, Literal):
type_ = type_.copy_with(tuple(map(validate_type, get_args(type_))))
return type_


class BaseModelMetaclass(pydantic.main.ModelMetaclass):
@staticmethod
def __validate_cls_namespace__(name: str, namespace: Dict) -> None: # noqa C901
def __validate_cls_namespace__( # noqa C901
name: str, bases: Tuple[type, ...], namespace: Dict
) -> None:
"""Validate the class name space in place"""
annotations = resolve_annotations(
namespace.get("__annotations__", {}), namespace.get("__module__")
)
config = validate_config(namespace.get("Config", BaseODMConfig), name)
odm_fields: Dict[str, ODMBaseField] = {}
references: List[str] = []
bson_serialized_fields: Set[str] = set()
mutable_fields: Set[str] = set()
base_configs: List[type] = [BaseODMConfig]
for base in reversed(bases):
if issubclass(base, _BaseODMModel) and base not in (Model, EmbeddedModel):
odm_fields.update(smart_deepcopy(base.__odm_fields__))
references.extend(base.__references__)
bson_serialized_fields.update(base.__bson_serialized_fields__)
mutable_fields.update(base.__mutable_fields__)
base_configs.append(base.Config)

# Ensure the namespace config is valid and combine it with the base configs
config = namespace.get("Config")
if config:
validate_config(config, name)
base_configs.append(config)
base_configs.reverse()
config = combine_configs(
*base_configs, validate_all=True, validate_assignment=True
)

# Make sure all fields are defined with type annotation
for field_name, value in namespace.items():
Expand Down Expand Up @@ -347,6 +363,8 @@ def __validate_cls_namespace__(name: str, namespace: Dict) -> None: # noqa C901
if duplicate_key is not None:
raise TypeError(f"Duplicated key_name: {duplicate_key} in {name}")

# Avoid getting the docstrings from the parent classes
namespace.setdefault("__doc__", "")
namespace["__annotations__"] = annotations
namespace["__odm_fields__"] = odm_fields
namespace["__references__"] = tuple(references)
Expand All @@ -369,21 +387,6 @@ def __new__(
"Model",
"EmbeddedModel",
)

if is_custom_cls:
# Handle calls from pydantic.main.create_model (used internally by FastAPI)
patched_bases = []
for b in bases:
if hasattr(b, "__pydantic_model__"):
patched_bases.append(b.__pydantic_model__)
else:
patched_bases.append(b)
bases = tuple(patched_bases)
# Nullify unset docstring (to avoid getting the docstrings from the parent
# classes)
if namespace.get("__doc__", None) is None:
namespace["__doc__"] = ""

cls = super().__new__(mcs, name, bases, namespace, **kwargs)

if is_custom_cls:
Expand Down Expand Up @@ -424,7 +427,7 @@ def __new__( # noqa C901
if namespace.get("__module__") != "odmantic.model" and namespace.get(
"__qualname__"
) not in ("_BaseODMModel", "Model"):
mcs.__validate_cls_namespace__(name, namespace)
mcs.__validate_cls_namespace__(name, bases, namespace)
config: BaseODMConfig = namespace["Config"]
primary_field: Optional[str] = None
odm_fields: Dict[str, ODMBaseField] = namespace["__odm_fields__"]
Expand Down Expand Up @@ -487,7 +490,7 @@ def __new__(
if namespace.get("__module__") != "odmantic.model" and namespace.get(
"__qualname__"
) not in ("_BaseODMModel", "EmbeddedModel"):
mcs.__validate_cls_namespace__(name, namespace)
mcs.__validate_cls_namespace__(name, bases, namespace)
odm_fields: Dict[str, ODMBaseField] = namespace["__odm_fields__"]
for field in odm_fields.values():
if isinstance(field, ODMField) and field.primary_field:
Expand Down Expand Up @@ -691,46 +694,6 @@ def dict( # type: ignore # Missing deprecated/ unsupported parameters
exclude_none=exclude_none,
)

def __doc(
self,
raw_doc: Dict[str, Any],
model: Type["_BaseODMModel"],
include: Optional["AbstractSetIntStr"] = None,
) -> Dict[str, Any]:
doc: Dict[str, Any] = {}
for field_name, field in model.__odm_fields__.items():
if include is not None and field_name not in include:
continue
if isinstance(field, ODMReference):
doc[field.key_name] = raw_doc[field_name][field.model.__primary_field__]
elif isinstance(field, ODMEmbedded):
doc[field.key_name] = self.__doc(raw_doc[field_name], field.model, None)
elif isinstance(field, ODMEmbeddedGeneric):
if field.generic_origin is dict:
doc[field.key_name] = {
item_key: self.__doc(item_value, field.model)
for item_key, item_value in raw_doc[field_name].items()
}
else:
doc[field.key_name] = [
self.__doc(item, field.model) for item in raw_doc[field_name]
]
elif field_name in model.__bson_serialized_fields__:
doc[field.key_name] = model.__fields__[field_name].type_.__bson__(
raw_doc[field_name]
)
else:
doc[field.key_name] = raw_doc[field_name]

if model.Config.extra == "allow":
extras = set(raw_doc.keys()) - set(model.__odm_fields__.keys())
for extra in extras:
value = raw_doc[extra]
subst_type = validate_type(type(value))
bson_serialization_method = getattr(subst_type, "__bson__", lambda x: x)
doc[extra] = bson_serialization_method(raw_doc[extra])
return doc

def doc(self, include: Optional["AbstractSetIntStr"] = None) -> Dict[str, Any]:
"""Generate a document representation of the instance (as a dictionary).

Expand All @@ -741,10 +704,37 @@ def doc(self, include: Optional["AbstractSetIntStr"] = None) -> Dict[str, Any]:
Returns:
the document associated to the instance
"""
raw_doc = self.dict()
doc = self.__doc(raw_doc, type(self), include)
doc: Dict[str, Any] = {}
fields = self.__fields__
bson_serialized_fields = self.__bson_serialized_fields__
for field_name, field in self.__odm_fields__.items():
if include is None or field_name in include:
value = getattr(self, field_name)
if field_name in bson_serialized_fields:
doc_value = fields[field_name].type_.__bson__(value)
elif isinstance(field, ODMReference):
doc_value = getattr(value, value.__primary_field__)
else:
doc_value = self.__doc_value(value)
doc[field.key_name] = doc_value

if self.Config.extra == "allow":
odm_field_names = self.__odm_fields__.keys()
for field_name, value in self:
if field_name not in odm_field_names:
to_bson = getattr(validate_type(type(value)), "__bson__", None)
if to_bson is not None:
value = to_bson(value)
doc[field_name] = value
return doc

def __doc_value(self, value: Any) -> Any:
if isinstance(value, dict):
return {k: self.__doc_value(v) for k, v in value.items()}
if isinstance(value, (list, tuple, set)):
return list(map(self.__doc_value, value))
return value.doc() if isinstance(value, _BaseODMModel) else value

@classmethod
def parse_doc(cls: Type[BaseT], raw_doc: Dict) -> BaseT:
"""Parse a BSON document into an instance of the Model
Expand Down Expand Up @@ -817,13 +807,10 @@ def _parse_doc_to_obj( # noqa C901 # TODO: refactor document parsing
)
obj[field_name] = value
elif isinstance(field, ODMEmbeddedGeneric):
value = Undefined
raw_value = raw_doc.get(field.key_name, Undefined)
if raw_value is not Undefined:
if isinstance(raw_value, list) and (
field.generic_origin is list
or field.generic_origin is tuple
or field.generic_origin is set
if field.generic_origin in (list, tuple, set) and isinstance(
raw_value, list
):
value = []
for i, item in enumerate(raw_value):
Expand All @@ -835,7 +822,7 @@ def _parse_doc_to_obj( # noqa C901 # TODO: refactor document parsing
else:
value.append(item)
obj[field_name] = value
elif isinstance(raw_value, dict) and field.generic_origin is dict:
elif field.generic_origin is dict and isinstance(raw_value, dict):
value = {}
for item_key, item_value in raw_value.items():
sub_errors, item_value = field.model._parse_doc_to_obj(
Expand All @@ -847,6 +834,15 @@ def _parse_doc_to_obj( # noqa C901 # TODO: refactor document parsing
else:
value[item_key] = item_value
obj[field_name] = value
elif field.generic_origin is Union: # actually Optional
if raw_value is not None:
sub_errors, value = field.model._parse_doc_to_obj(
raw_value, base_loc=base_loc + (field_name,)
)
errors.extend(sub_errors)
obj[field_name] = value
else:
obj[field_name] = None
else:
errors.append(
ErrorWrapper(
Expand All @@ -855,6 +851,7 @@ def _parse_doc_to_obj( # noqa C901 # TODO: refactor document parsing
)
)
else:
value = Undefined
if not field.is_required_in_doc():
value = field.get_default_importing_value()
if value is Undefined:
Expand Down
8 changes: 1 addition & 7 deletions odmantic/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,6 @@
# FIXME: add this back to coverage once 3.11 is released
from typing import dataclass_transform # noqa: F401 # pragma: no cover

HAS_GENERIC_ALIAS_BUILTIN = sys.version_info[:3] >= (3, 9, 0) # PEP 560
if HAS_GENERIC_ALIAS_BUILTIN:
from typing import GenericAlias # type: ignore
else:
from typing import _GenericAlias as GenericAlias # type: ignore # noqa: F401


# Taken from https://github.com/pydantic/pydantic/pull/2392
# Reimplemented here to avoid a dependency deprecation on pydantic1.7
Expand All @@ -32,7 +26,7 @@ def lenient_issubclass(
try:
return isinstance(cls, type) and issubclass(cls, class_or_tuple)
except TypeError:
if isinstance(cls, GenericAlias):
if hasattr(cls, "__origin__"):
return False
raise # pragma: no cover

Expand Down
Loading