Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pydantic v2 migration #3007

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
numpy~=1.16
pandas>=1.0,<3
pydantic~=1.7
pydantic~=2.1
pydantic-settings>=2.0.3
tqdm~=4.23
toolz~=0.10

Expand Down
14 changes: 3 additions & 11 deletions src/gluonts/core/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import Any, Type, TypeVar

import numpy as np
from pydantic import BaseConfig, BaseModel, ValidationError, create_model
from pydantic import ConfigDict, BaseModel, ValidationError, create_model

from gluonts.core import fqname_for
from gluonts.exceptions import GluonTSHyperparametersError
Expand Down Expand Up @@ -217,15 +217,7 @@ class BaseValidatedInitializerModel(BaseModel):
Decorates an initializer methods with argument validation logic.
"""

class Config(BaseConfig):
"""
`Config <https://pydantic-docs.helpmanual.io/#model-config>`_ for the
Pydantic model inherited by all :func:`validated` initializers.

Allows the use of arbitrary type annotations in initializer parameters.
"""

arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True)


def validated(base_model=None):
Expand Down Expand Up @@ -306,7 +298,7 @@ def validator(init):
if base_model is None:
PydanticModel = create_model(
f"{init_clsnme}Model",
__config__=BaseValidatedInitializerModel.Config,
__config__=BaseValidatedInitializerModel.model_config,
**init_fields,
)
else:
Expand Down
4 changes: 2 additions & 2 deletions src/gluonts/core/serde/_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def _call(self, env):
def dataclass(
cls=None,
*,
init=True,
init=False,
repr=True,
eq=True,
order=False,
Expand All @@ -131,7 +131,7 @@ def dataclass(
"""

# assert frozen
assert init
assert init is False

if cls is None:
return _dataclass
Expand Down
2 changes: 1 addition & 1 deletion src/gluonts/core/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def _set_(self, dct, key, value):
# assignment: `settings.foo = {"b": 1}` should only set `b`
# Thus we check whether we are dealing with a pydantic model and if
# we are also assigning a `dict`:
type_ = model.__fields__[key].type_
type_ = model.model_fields[key].annotation

if issubclass(type_, pydantic.BaseModel) and isinstance(
value, dict
Expand Down
9 changes: 5 additions & 4 deletions src/gluonts/dataset/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

from . import Dataset, DatasetCollection, DataEntry, DataBatch # noqa
from . import jsonl, DatasetWriter
from pydantic import ConfigDict


arrow: Optional[ModuleType]
Expand Down Expand Up @@ -63,8 +64,7 @@ class MetaData(pydantic.BaseModel):

prediction_length: Optional[int] = None

class Config(pydantic.BaseConfig):
allow_population_by_field_name = True
model_config = ConfigDict(allow_population_by_field_name=True)


class SourceContext(NamedTuple):
Expand Down Expand Up @@ -277,8 +277,9 @@ class ProcessStartField(pydantic.BaseModel):
Frequency to use. This must be a valid Pandas frequency string.
"""

class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
populate_by_name=True, arbitrary_types_allowed=True
)

freq: Union[str, pd.DateOffset]
use_timestamp: bool = False
Expand Down
3 changes: 2 additions & 1 deletion src/gluonts/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@

from typing import Any

from pydantic.error_wrappers import ValidationError, display_errors
from pydantic import ValidationError
from pydantic.v1.error_wrappers import display_errors


class GluonTSException(Exception):
Expand Down
7 changes: 4 additions & 3 deletions src/gluonts/ext/rotbaum/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import List, Union, Optional

import numpy as np
from pydantic import BaseModel, root_validator
from pydantic import model_validator, BaseModel


class FeatureImportanceResult(BaseModel):
Expand All @@ -25,7 +25,8 @@ class FeatureImportanceResult(BaseModel):
feat_dynamic_real: List[Union[List[float], float]]
feat_dynamic_cat: List[Union[List[float], float]]

@root_validator()
@model_validator(mode="before")
@classmethod
def check_shape(cls, values):
"""
Validate the second dimension is the same for 2d results and all fields share the same dimensionality
Expand Down Expand Up @@ -60,5 +61,5 @@ def mean(self, axis=None) -> "FeatureImportanceResult":


class ExplanationResult(BaseModel):
time_quantile_aggregated_result: Optional[FeatureImportanceResult]
time_quantile_aggregated_result: Optional[FeatureImportanceResult] = None
quantile_aggregated_result: FeatureImportanceResult
5 changes: 4 additions & 1 deletion src/gluonts/meta/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def get_version_and_cmdclass(version_file):
)
"""

import os
import subprocess
from pathlib import Path

Expand Down Expand Up @@ -250,4 +251,6 @@ def make_release_tree(self, base_dir, files):
return {"sdist": sdist, "build_py": build_py}


__version__ = get_version(fallback="0.0.0")
__version__ = get_version(
fallback=os.environ.get("GLUONTS_FALLBACK_VERSION", "0.0.0")
)
12 changes: 6 additions & 6 deletions src/gluonts/shell/sagemaker/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from pathlib import Path
from typing import Dict, Optional, Tuple

from pydantic import BaseModel
from pydantic import BaseModel, RootModel

from .dyn import install_and_restart
from .params import decode_sagemaker_parameters
Expand All @@ -30,17 +30,17 @@ class DataConfig(BaseModel):
RecordWrapperType: Optional[str] = None


class InpuDataConfig(BaseModel):
__root__: Dict[str, DataConfig]
class InpuDataConfig(RootModel):
root: Dict[str, DataConfig]

def __getitem__(self, item):
return self.__root__[item]
return self.root[item]

def channels(self):
return self.__root__
return self.root

def channel_names(self):
return list(self.__root__.keys())
return list(self.root.keys())


class TrainPaths:
Expand Down
6 changes: 2 additions & 4 deletions src/gluonts/shell/serve/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import List, Optional, Type, Union

from flask import Flask
from pydantic import BaseSettings
from pydantic_settings import BaseSettings, SettingsConfigDict
kashif marked this conversation as resolved.
Show resolved Hide resolved

import gluonts
from gluonts.core import fqname_for
Expand All @@ -39,9 +39,7 @@


class Settings(BaseSettings):
# see: https://pydantic-docs.helpmanual.io/#settings
class Config:
env_prefix = ""
model_config = SettingsConfigDict(env_prefix="")

model_server_workers: Optional[int] = None
max_content_length: int = 6 * MB
Expand Down
8 changes: 2 additions & 6 deletions src/gluonts/shell/serve/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from typing_extensions import Literal

from flask import Flask, Response, jsonify, request
from pydantic import BaseModel, Field
from pydantic import ConfigDict, BaseModel, Field

from gluonts.dataset.common import ListDataset
from gluonts.dataset.jsonl import encode_json
Expand All @@ -42,11 +42,7 @@ class ForecastConfig(BaseModel):
output_types: Set[OutputType] = {"quantiles", "mean"}
# FIXME: validate list elements
quantiles: List[str] = ["0.1", "0.5", "0.9"]

class Config:
allow_population_by_field_name = True
# store additional fields
extra = "allow"
model_config = ConfigDict(populate_by_name=True, extra="allow")

def as_json_dict(self, forecast: Forecast) -> dict:
result = {}
Expand Down
6 changes: 2 additions & 4 deletions src/gluonts/transform/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import Tuple

import numpy as np
from pydantic import BaseModel
from pydantic import ConfigDict, BaseModel

from gluonts.dataset.stat import ScaleHistogram

Expand All @@ -31,9 +31,7 @@ class InstanceSampler(BaseModel):
axis: int = -1
min_past: int = 0
min_future: int = 0

class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True)

def _get_bounds(self, ts: np.ndarray) -> Tuple[int, int]:
return (
Expand Down
2 changes: 1 addition & 1 deletion test/dataset/artificial/test_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def test_recipe_dataset(recipe) -> None:
freq="D",
feat_static_real=[BasicFeatureInfo(name="feat_static_real_000")],
feat_static_cat=[
CategoricalFeatureInfo(name="foo", cardinality=10)
CategoricalFeatureInfo(name="foo", cardinality="10")
],
feat_dynamic_real=[BasicFeatureInfo(name="binary_causal")],
),
Expand Down
Loading