Skip to content

Commit

Permalink
Merge pull request #88 from slaclab/pydantic_v2
Browse files Browse the repository at this point in the history
Pydantic v2
  • Loading branch information
roussel-ryan committed Oct 25, 2023
2 parents ca5e467 + 0937301 commit 73cb485
Show file tree
Hide file tree
Showing 25 changed files with 536 additions and 519 deletions.
4 changes: 2 additions & 2 deletions conda-recipe/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ requirements:
- python
- setuptools
- pip
- pydantic==1.10.9
- pydantic>2.3
run:
- python
- pydantic==1.10.9
- pydantic>2.3
- numpy
- pyyaml

Expand Down
2 changes: 1 addition & 1 deletion dev-environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ channels:
- conda-forge
dependencies:
- python=3.9
- pydantic==1.10.9
- pydantic>2.3
- numpy
- pyyaml
- tensorflow
Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ channels:
- conda-forge
dependencies:
- python=3.9
- pydantic==1.10.9
- pydantic>2.3
- numpy
- pyyaml
169 changes: 126 additions & 43 deletions lume_model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
import yaml
import logging
from abc import ABC, abstractmethod
from typing import Any, Callable, Union
from typing import Any, Callable, Union, TextIO
from types import FunctionType, MethodType

import numpy as np
from pydantic import BaseModel, validator
from pydantic import BaseModel, ConfigDict, field_validator, SerializeAsAny

from lume_model.variables import (
InputVariable,
OutputVariable,
OutputVariable, ScalarInputVariable, ScalarOutputVariable,
)
from lume_model.utils import (
try_import_module,
Expand All @@ -23,7 +23,6 @@

logger = logging.getLogger(__name__)


JSON_ENCODERS = {
# function/method type distinguished for class members and not recognized as callables
FunctionType: lambda x: f"{x.__module__}.{x.__qualname__}",
Expand Down Expand Up @@ -96,7 +95,7 @@ def process_keras_model(


def recursive_serialize(
v,
v: dict[str, Any],
base_key: str = "",
file_prefix: Union[str, os.PathLike] = "",
save_models: bool = True,
Expand All @@ -121,11 +120,13 @@ def recursive_serialize(
if isinstance(value, dict):
v[key] = recursive_serialize(value, key)
elif torch is not None and isinstance(value, torch.nn.Module):
v[key] = process_torch_module(value, base_key, key, file_prefix, save_models)
v[key] = process_torch_module(value, base_key, key, file_prefix,
save_models)
elif isinstance(value, list) and torch is not None and any(
isinstance(ele, torch.nn.Module) for ele in value):
v[key] = [
process_torch_module(value[i], base_key, f"{key}_{i}", file_prefix, save_models)
process_torch_module(value[i], base_key, f"{key}_{i}", file_prefix,
save_models)
for i in range(len(value))
]
elif keras is not None and isinstance(value, keras.Model):
Expand Down Expand Up @@ -164,7 +165,6 @@ def recursive_deserialize(v):
def json_dumps(
v,
*,
default,
base_key="",
file_prefix: Union[str, os.PathLike] = "",
save_models: bool = True,
Expand All @@ -173,16 +173,15 @@ def json_dumps(
Args:
v: Object to dump.
default: Default for json.dumps().
base_key: Base key for serialization.
file_prefix: Prefix for generated filenames.
save_models: Determines whether models are saved to file.
Returns:
JSON formatted string.
"""
v = recursive_serialize(v, base_key, file_prefix, save_models)
v = json.dumps(v, default=default)
v = recursive_serialize(v.model_dump(), base_key, file_prefix, save_models)
v = json.dumps(v)
return v


Expand Down Expand Up @@ -232,7 +231,8 @@ def model_kwargs_from_dict(config: dict) -> dict:
"""
config = deserialize_variables(config)
if all(key in config.keys() for key in ["input_variables", "output_variables"]):
config["input_variables"], config["output_variables"] = variables_from_dict(config)
config["input_variables"], config["output_variables"] = variables_from_dict(
config)
_ = config.pop("model_class", None)
return config

Expand All @@ -247,34 +247,66 @@ class LUMEBaseModel(BaseModel, ABC):
input_variables: List defining the input variables and their order.
output_variables: List defining the output variables and their order.
"""
input_variables: list[InputVariable]
output_variables: list[OutputVariable]
input_variables: list[SerializeAsAny[InputVariable]]
output_variables: list[SerializeAsAny[OutputVariable]]

class Config:
extra = "allow"
json_dumps = json_dumps
json_loads = json_loads
validate_assignment = True
arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True)

def __init__(
self,
config: Union[dict, str] = None,
**kwargs,
):
@field_validator("input_variables", mode="before")
def validate_input_variables(cls, value):
new_value = []
if isinstance(value, dict):
for name, val in value.items():
if isinstance(val, dict):
if val["variable_type"] == "scalar":
new_value.append(ScalarInputVariable(name=name, **val))
elif isinstance(val, InputVariable):
new_value.append(val)
else:
raise TypeError(f"type {type(val)} not supported")
elif isinstance(value, list):
new_value = value

return new_value

@field_validator("output_variables", mode="before")
def validate_output_variables(cls, value):
new_value = []
if isinstance(value, dict):
for name, val in value.items():
if isinstance(val, dict):
if val["variable_type"] == "scalar":
new_value.append(ScalarOutputVariable(name=name, **val))
elif isinstance(val, OutputVariable):
new_value.append(val)
else:
raise TypeError(f"type {type(val)} not supported")
elif isinstance(value, list):
new_value = value

return new_value

def __init__(self, *args, **kwargs):
"""Initializes LUMEBaseModel.
Args:
config: Model configuration as dictionary, YAML or JSON formatted string or file path. This overrides
all other arguments.
*args: Accepts a single argument which is the model configuration as dictionary, YAML or JSON
formatted string or file path.
**kwargs: See class attributes.
"""
if config is not None:
self.__init__(**parse_config(config))
if len(args) == 1:
if len(kwargs) > 0:
raise ValueError("Cannot specify YAML string and keyword arguments for LUMEBaseModel init.")
super().__init__(**parse_config(args[0]))
elif len(args) > 1:
raise ValueError(
"Arguments to LUMEBaseModel must be either a single YAML string "
"or keyword arguments passed directly to pydantic."
)
else:
super().__init__(**kwargs)

@validator("input_variables", "output_variables")
@field_validator("input_variables", "output_variables")
def unique_variable_names(cls, value):
verify_unique_variable_names(value)
return value
Expand All @@ -291,29 +323,80 @@ def output_names(self) -> list[str]:
def evaluate(self, input_dict: dict[str, Any]) -> dict[str, Any]:
pass

def to_json(self, **kwargs) -> str:
return json_dumps(self, **kwargs)

def dict(self, **kwargs) -> dict[str, Any]:
config = super().model_dump(**kwargs)
return {"model_class": self.__class__.__name__} | config

def json(self, **kwargs) -> str:
result = self.to_json(**kwargs)
config = json.loads(result)
config = {"model_class": self.__class__.__name__} | config

return json.dumps(config)

def yaml(
self,
file: Union[str, os.PathLike] = None,
save_models: bool = True,
base_key: str = "",
file_prefix: str = "",
save_models: bool = False,
) -> str:
"""Returns and optionally saves YAML formatted string defining the model.
"""Serializes the object and returns a YAML formatted string defining the model.
Args:
file: If not None, the YAML formatted string is saved to given file path.
save_models: Determines whether models are saved to file.
base_key: Base key for serialization.
file_prefix: Prefix for generated filenames.
save_models: Determines whether models are saved to file.
Returns:
YAML formatted string defining the model.
"""
file_prefix = ""
if file is not None:
file_prefix = os.path.splitext(file)[0]
config = json.loads(self.json(base_key=base_key, file_prefix=file_prefix, save_models=save_models))
s = yaml.dump({"model_class": self.__class__.__name__} | config,
output = json.loads(
self.to_json(
base_key=base_key,
file_prefix=file_prefix,
save_models=save_models,
)
)
s = yaml.dump({"model_class": self.__class__.__name__} | output,
default_flow_style=None, sort_keys=False)
if file is not None:
with open(file, "w") as f:
f.write(s)
return s

def dump(
self,
file: Union[str, os.PathLike],
base_key: str = "",
save_models: bool = True,
):
"""Returns and optionally saves YAML formatted string defining the model.
Args:
file: File path to which the YAML formatted string and corresponding files are saved.
base_key: Base key for serialization.
save_models: Determines whether models are saved to file.
"""
file_prefix = os.path.splitext(file)[0]
with open(file, "w") as f:
f.write(
self.yaml(
base_key=base_key,
file_prefix=file_prefix,
save_models=save_models,
)
)

@classmethod
def from_file(cls, filename: str):
if not os.path.exists(filename):
raise OSError(f"file {filename} is not found")

with open(filename, "r") as file:
return cls.from_yaml(file)

@classmethod
def from_yaml(cls, yaml_obj: [str, TextIO]):
return cls.model_validate(yaml.safe_load(yaml_obj))


45 changes: 0 additions & 45 deletions lume_model/keras/README.md

This file was deleted.

4 changes: 0 additions & 4 deletions lume_model/keras/__init__.py

This file was deleted.

5 changes: 3 additions & 2 deletions lume_model/models.py → lume_model/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@

# models requiring torch
try:
from lume_model.torch import TorchModel, TorchModule
from lume_model.models.torch_model import TorchModel
from lume_model.models.torch_module import TorchModule
registered_models += [TorchModel, TorchModule]
except ModuleNotFoundError:
pass

# models requiring keras
try:
from lume_model.keras import KerasModel
from lume_model.models.keras_model import KerasModel
registered_models += [KerasModel]
except ModuleNotFoundError:
pass
Expand Down
File renamed without changes.
Loading

0 comments on commit 73cb485

Please sign in to comment.