Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pydantic V2 compatibility #250

Draft
wants to merge 9 commits into
base: master
Choose a base branch
from
6 changes: 3 additions & 3 deletions doc/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
lxml=4.9.2
lxml==4.9.2
requests==2.28.1
requests-cache==0.9.7
pandas==1.5.2
pydantic==2.5.2
pandas==2.1.3
sphinx==5.3
ipython==8.7.0

3 changes: 2 additions & 1 deletion pandasdmx/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,8 @@ def _request_from_args(self, kwargs):
url_parts.append(key)

# Assemble final URL
url = "/".join(filter(None, url_parts))
url_parts = [str(x) for x in filter(None, url_parts)]
url = "/".join(url_parts)

# Parameters: set 'references' to sensible defaults
if "references" not in parameters:
Expand Down
3 changes: 2 additions & 1 deletion pandasdmx/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,8 @@ def add(self, obj: model.IdentifiableArtefact):
for field, field_info in direct_fields(self.__class__).items():
# NB for some reason mypy complains here, but not in __contains__(), below
if isinstance(
obj, get_args(field_info.outer_type_)[1], # type: ignore [attr-defined]
obj,
get_args(field_info.outer_type_)[1], # type: ignore [attr-defined]
):
getattr(self, field)[obj.id] = obj
return
Expand Down
50 changes: 34 additions & 16 deletions pandasdmx/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from operator import attrgetter, itemgetter
from typing import (
Any,
ClassVar,
Dict,
Generator,
Generic,
Expand All @@ -49,6 +50,10 @@
)
from warnings import warn

from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema
from pydantic_core import core_schema

from pandasdmx.util import (
BaseModel,
DictLike,
Expand Down Expand Up @@ -186,18 +191,28 @@ def __eq__(self, other):
return NotImplemented

@classmethod
def __get_validators__(cls):
yield cls.__validate
def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler
) -> CoreSchema:
return core_schema.chain_schema(
[
core_schema.with_info_plain_validator_function(
function=cls.__validate,
),
]
)

@classmethod
def __validate(cls, value, values, config, field):
# Any value that the constructor can handle can be assigned
def __validate(cls, value, info):
# Any value except None that the constructor can handle can be assigned
if value == None:
raise ValueError
if not isinstance(value, InternationalString):
value = InternationalString(value)

try:
# Update existing value
existing = values[field.name]
existing = info.data[info.field_name]
existing.localizations.update(value.localizations)
return existing
except KeyError:
Expand Down Expand Up @@ -602,7 +617,7 @@ class ItemScheme(MaintainableArtefact, Generic[IT]):
# TODO add delete()
# TODO add sorting capability; perhaps sort when new items are inserted

is_partial: Optional[bool]
is_partial: Optional[bool] = None

#: Members of the ItemScheme. Both ItemScheme and Item are abstract classes.
#: Concrete classes are paired: for example, a :class:`.Codelist` contains
Expand Down Expand Up @@ -734,7 +749,7 @@ def setdefault(self, obj=None, **kwargs) -> IT:
kwargs["parent"] = self[parent]

# Instantiate an object of the correct class
obj = self._Item(**kwargs)
obj = self.__class__._Item.get_default()(**kwargs)

try:
# Add the object to the ItemScheme
Expand All @@ -745,9 +760,6 @@ def setdefault(self, obj=None, **kwargs) -> IT:
return obj


Item.update_forward_refs()


# §3.6: Structure


Expand Down Expand Up @@ -863,7 +875,7 @@ class ComponentList(IdentifiableArtefact, Generic[CT]):
#:
components: List[CT] = []
#:
auto_order = 1
auto_order: ClassVar[int] = 1

# The default type of the Components in the ComponentList. See comment on
# ItemScheme._Item
Expand Down Expand Up @@ -916,7 +928,7 @@ def getdefault(self, id, cls=None, **kwargs) -> CT:
# order property
try:
component.order = self.auto_order
self.auto_order += 1
self.__class__.auto_order += 1
except ValueError:
pass

Expand Down Expand Up @@ -1040,7 +1052,7 @@ class Agency(Organisation):
# Update forward references to 'Agency'
for cls in list(locals().values()):
if isclass(cls) and issubclass(cls, MaintainableArtefact):
cls.update_forward_refs()
cls.model_rebuild()


class OrganisationScheme:
Expand Down Expand Up @@ -1464,8 +1476,8 @@ def assign_order(self):
pass


DimensionRelationship.update_forward_refs()
GroupRelationship.update_forward_refs()
DimensionRelationship.model_rebuild()
GroupRelationship.model_rebuild()


class _NullConstraintClass:
Expand Down Expand Up @@ -1789,7 +1801,7 @@ class KeyValue(BaseModel):
#: The actual value.
value: Any
#:
value_for: Optional[Dimension] = None
value_for: Optional[DimensionComponent] = None

def __init__(self, *args, **kwargs):
args, kwargs = value_for_dsd_ref("dimension", args, kwargs)
Expand All @@ -1810,6 +1822,9 @@ def __eq__(self, other):
else:
return self.value == other

def __lt__(self, other):
return self.value < other.value

def __str__(self):
return "{0.id}={0.value}".format(self)

Expand Down Expand Up @@ -2302,6 +2317,9 @@ class ProvisionAgreement(MaintainableArtefact, ConstrainableArtefact):
data_provider: Optional[DataProvider] = None


Item.model_rebuild()


#: The SDMX-IM defines 'packages'; these are used in URNs.
PACKAGE = dict()

Expand Down
9 changes: 5 additions & 4 deletions pandasdmx/reader/sdmxml.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ class _NoText:
# Sentinel value for XML elements with no text; used to distinguish from "" and None
NoText = _NoText()


class Reference:
"""Temporary class for references.

Expand Down Expand Up @@ -205,7 +206,7 @@ def __str__(self): # pragma: no cover

class XSDResolver(etree.Resolver):
"""
Resolve XSD imports to locate them within <user_data_dir>/pandaSDMX/sdmx_2_1.
Resolve XSD imports to locate them within <user_data_dir>/pandaSDMX/sdmx_2_1.
"""

def __init__(self, *args, schema_dir=None, **kwargs):
Expand Down Expand Up @@ -249,7 +250,7 @@ def validate_message(msg, schema_dir=None):
must be installed first. See the docs on
:func:`pandasdmx.api.install_schemas` and
:meth:`pandasdmx.api.Request.validate`.

Returns whatever lxml.etree.XMLSchema.validate returns
"""
msg_doc = etree.parse(msg)
Expand Down Expand Up @@ -829,7 +830,7 @@ def _ref(reader, elem):

@end("com:Annotation")
def _a(reader, elem):
url=reader.pop_single("AnnotationURL")
url = reader.pop_single("AnnotationURL")
args = dict(
title=reader.pop_single("AnnotationTitle"),
type=reader.pop_single("AnnotationType"),
Expand Down Expand Up @@ -912,7 +913,7 @@ def _itemscheme(reader, elem):
is_ = reader.maintainable(cls, elem)

# Iterate over all Item objects *and* their children
iter_all = chain(*[iter(item) for item in reader.pop_all(cls._Item)])
iter_all = chain(*[iter(item) for item in reader.pop_all(cls._Item.default)])

# Set of objects already added to `items`
seen = dict()
Expand Down
6 changes: 3 additions & 3 deletions pandasdmx/source/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class Source(BaseModel):
id: str
#: Optional API IDTakes precedence over id when URL is constructed
# Useful if a provider offers several APIs
api_id: Optional[str]
api_id: Optional[str] = None

#: Base URL for queries
url: Optional[HttpUrl]
Expand All @@ -40,7 +40,7 @@ class Source(BaseModel):
name: str

#: documentation URL of the data source
documentation: Optional[HttpUrl]
documentation: Optional[HttpUrl] = None

headers: Dict[str, Any] = {}

Expand Down Expand Up @@ -124,7 +124,7 @@ def modify_request_args(self, kwargs):

@validator("id")
def _validate_id(cls, value):
assert getattr(cls, "_id", value) == value
assert cls.__dict__.get("_id", value) == value
return value

@validator("data_content_type", pre=True)
Expand Down
3 changes: 2 additions & 1 deletion pandasdmx/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pydantic
import pytest
from pytest import raises
import re

from pandasdmx import model
from pandasdmx.model import (
Expand Down Expand Up @@ -199,7 +200,7 @@ def test_internationalstring():
assert str(i2.name) == "European Central Bank"

# Creating with name=None raises an exception…
with raises(pydantic.ValidationError, match="none is not an allowed value"):
with raises(pydantic.ValidationError, match=re.compile(r"name\n.*input_value=None")):
Item(id="ECB", name=None)

# …giving empty dict is equivalent to giving nothing
Expand Down
31 changes: 17 additions & 14 deletions pandasdmx/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import pydantic
import requests
from pydantic import Field, ValidationError, validator
from pydantic.class_validators import make_generic_validator
from pydantic.typing import get_origin # type: ignore [attr-defined]
from pydantic import validator
from typing import get_origin # type: ignore [attr-defined]

try:
import requests_cache
Expand Down Expand Up @@ -162,7 +162,7 @@ class BaseModel(pydantic.BaseModel):
"""Common settings for :class:`pydantic.BaseModel` in :mod:`pandasdmx`."""

class Config:
copy_on_model_validation = 'none'
copy_on_model_validation = "none"
validate_assignment = True


Expand Down Expand Up @@ -245,22 +245,25 @@ def __get_validators__(cls):
yield cls._validate_whole

@classmethod
def _validate_whole(cls, v, field: pydantic.fields.ModelField):
def _validate_whole(cls, v, field: str):
"""Validate `v` as an entire DictLike object."""
# Convert anything that can be converted to a dict(). pydantic internals catch
# most other invalid types, e.g. set(); no need to handle them here.
if cls == DictLike:
return v

result = cls(v)

# Reference to the pydantic.field.ModelField for the entries
result.__field = field
result.__field = cls.model_fields[field]

return result

def _validate_entry(self, key, value):
"""Validate one `key`/`value` pair."""
try:
# Use pydantic's validation machinery
v, error = self.__field._validate_mapping_like(
v, error = self.__class.model_fields[self.__field]._validate_mapping_like(
((key, value),), values={}, loc=(), cls=None
)
except AttributeError:
Expand Down Expand Up @@ -332,11 +335,11 @@ def validate_dictlike(cls):
lambda item: get_origin(item[1]) is DictLike, cls.__annotations__.items()
):
# Add the validator(s)
field = cls.__fields__[name]
field.post_validators = field.post_validators or []
field.post_validators.extend(
make_generic_validator(v) for v in DictLike.__get_validators__()
)
for v in DictLike.__get_validators__():

@validator(name, allow_reuse=True, pre=False)
def _validator(cls, value):
return v(cls, value, name, None)

return cls

Expand Down Expand Up @@ -388,16 +391,16 @@ def parse_content_type(value: str) -> Tuple[str, Dict[str, Any]]:


@lru_cache()
def direct_fields(cls) -> Mapping[str, pydantic.fields.ModelField]:
def direct_fields(cls) -> Mapping[str, str]:
"""Return the :mod:`pydantic` fields defined on `obj` or its class.

This is like the ``__fields__`` attribute, but excludes the fields defined on any
parent class(es).
"""
return {
name: info
for name, info in cls.__fields__.items()
if name not in set(cls.mro()[1].__fields__.keys())
for name, info in cls.model_fields.items()
if name not in set(cls.mro()[1].model_fields.keys())
}


Expand Down
Loading