Merge pull request #87 from t-bz/pydantic_v2
Clean up and restructure model subpackages
roussel-ryan authored Oct 25, 2023
2 parents 03b77dd + d0246f6 commit 0937301
Showing 20 changed files with 401 additions and 415 deletions.
4 changes: 2 additions & 2 deletions conda-recipe/meta.yaml
@@ -16,10 +16,10 @@ requirements:
- python
- setuptools
- pip
- pydantic==1.10.9
- pydantic>2.3
run:
- python
- pydantic==1.10.9
- pydantic>2.3
- numpy
- pyyaml

2 changes: 1 addition & 1 deletion dev-environment.yml
@@ -4,7 +4,7 @@ channels:
- conda-forge
dependencies:
- python=3.9
- pydantic==1.10.9
- pydantic>2.3
- numpy
- pyyaml
- tensorflow
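Both the conda recipe and the dev environment file replace the hard pin pydantic==1.10.9 with the open constraint pydantic>2.3. A quick post-install sanity check, as a sketch (assumes the packaging helper package is available in the environment):

from packaging.version import Version
import pydantic

# The updated recipes require any pydantic release newer than 2.3 instead of the old 1.10.9 pin.
assert Version(pydantic.VERSION) > Version("2.3"), pydantic.VERSION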
68 changes: 43 additions & 25 deletions lume_model/base.py
@@ -173,7 +173,6 @@ def json_dumps(
Args:
v: Object to dump.
default: Default for json.dumps().
base_key: Base key for serialization.
file_prefix: Prefix for generated filenames.
save_models: Determines whether models are saved to file.
@@ -287,20 +286,23 @@ def validate_output_variables(cls, value):

return new_value

def __init__(
self,
config: Union[dict, str] = None,
**kwargs,
):
def __init__(self, *args, **kwargs):
"""Initializes LUMEBaseModel.
Args:
config: Model configuration as dictionary, YAML or JSON formatted string or file path. This overrides
all other arguments.
*args: Accepts a single argument which is the model configuration as dictionary, YAML or JSON
formatted string or file path.
**kwargs: See class attributes.
"""
if config is not None:
self.__init__(**parse_config(config))
if len(args) == 1:
if len(kwargs) > 0:
raise ValueError("Cannot specify YAML string and keyword arguments for LUMEBaseModel init.")
super().__init__(**parse_config(args[0]))
elif len(args) > 1:
raise ValueError(
"Arguments to LUMEBaseModel must be either a single YAML string "
"or keyword arguments passed directly to pydantic."
)
else:
super().__init__(**kwargs)

@@ -335,38 +337,54 @@ def json(self, **kwargs) -> str:

return json.dumps(config)

def yaml(self, **kwargs):
"""serialize first then dump to yaml string"""
def yaml(
self,
base_key: str = "",
file_prefix: str = "",
save_models: bool = False,
) -> str:
"""Serializes the object and returns a YAML formatted string defining the model.
Args:
base_key: Base key for serialization.
file_prefix: Prefix for generated filenames.
save_models: Determines whether models are saved to file.
Returns:
YAML formatted string defining the model.
"""
output = json.loads(
self.to_json(
**kwargs,
base_key=base_key,
file_prefix=file_prefix,
save_models=save_models,
)
)
return yaml.dump(output, default_flow_style=None, sort_keys=False)
s = yaml.dump({"model_class": self.__class__.__name__} | output,
default_flow_style=None, sort_keys=False)
return s

def dump(
self,
file: Union[str, os.PathLike],
save_models: bool = True,
base_key: str = "",
save_models: bool = True,
):
"""Returns and optionally saves YAML formatted string defining the model.
Args:
file: If not None, the YAML formatted string is saved to given file path.
save_models: Determines whether models are saved to file.
file: File path to which the YAML formatted string and corresponding files are saved.
base_key: Base key for serialization.
Returns:
YAML formatted string defining the model.
save_models: Determines whether models are saved to file.
"""
file_prefix = os.path.splitext(file)[0]

with open(file, "w") as f:
f.write(self.yaml(
base_key=base_key,
file_prefix=file_prefix,
save_models=save_models)
f.write(
self.yaml(
base_key=base_key,
file_prefix=file_prefix,
save_models=save_models,
)
)

@classmethod
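Taken together, the base.py changes make LUMEBaseModel.__init__ dispatch on a single positional argument (a dict, YAML or JSON string, or file path handled by parse_config) versus plain keyword arguments, and give yaml() and dump() explicit base_key, file_prefix and save_models parameters, with yaml() now prepending the model_class name to the serialized output. A standalone sketch of the same dispatch pattern in plain pydantic v2 (ConfigDemo and its fields are illustrative, not part of lume-model, and yaml.safe_load stands in for parse_config):

import yaml
from pydantic import BaseModel


class ConfigDemo(BaseModel):
    name: str = "demo"
    steps: int = 1

    def __init__(self, *args, **kwargs):
        # Mirror of the new LUMEBaseModel behavior: one positional config argument or kwargs, never both.
        if len(args) == 1:
            if len(kwargs) > 0:
                raise ValueError("Cannot specify YAML string and keyword arguments for init.")
            super().__init__(**yaml.safe_load(args[0]))
        elif len(args) > 1:
            raise ValueError("Pass either a single YAML string or keyword arguments, not both.")
        else:
            super().__init__(**kwargs)


print(ConfigDemo(name="kwargs path"))           # keyword arguments go straight to pydantic
print(ConfigDemo("name: yaml path\nsteps: 3"))  # a single YAML string is parsed first

In lume-model itself the single positional argument may also be a dictionary or a file path, since parse_config normalizes those before the values reach pydantic.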
45 changes: 0 additions & 45 deletions lume_model/keras/README.md

This file was deleted.

4 changes: 0 additions & 4 deletions lume_model/keras/__init__.py

This file was deleted.

5 changes: 3 additions & 2 deletions lume_model/models.py → lume_model/models/__init__.py
@@ -6,14 +6,15 @@

# models requiring torch
try:
from lume_model.torch import TorchModel, TorchModule
from lume_model.models.torch_model import TorchModel
from lume_model.models.torch_module import TorchModule
registered_models += [TorchModel, TorchModule]
except ModuleNotFoundError:
pass

# models requiring keras
try:
from lume_model.keras import KerasModel
from lume_model.models.keras_model import KerasModel
registered_models += [KerasModel]
except ModuleNotFoundError:
pass
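After the restructuring, each optional back end registers itself in registered_models only if its dependencies import cleanly; a missing torch or keras install simply leaves those entries out. A quick way to check what is available in a given environment (assumes lume-model with this change installed):

from lume_model.models import registered_models

# Only model classes whose optional dependencies imported successfully are listed here.
print([m.__name__ for m in registered_models])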
File renamed without changes.
20 changes: 11 additions & 9 deletions lume_model/keras/model.py → lume_model/models/keras_model.py
@@ -4,15 +4,15 @@

import keras
import numpy as np
from pydantic import validator
from pydantic import field_validator

from lume_model.base import LUMEBaseModel
from lume_model.variables import (
InputVariable,
OutputVariable,
ScalarInputVariable,
ScalarOutputVariable,
ImageOutputVariable,
# ScalarOutputVariable,
# ImageOutputVariable,
)

logger = logging.getLogger(__name__)
@@ -33,26 +33,28 @@ class KerasModel(LUMEBaseModel):

def __init__(
self,
config: Union[dict, str] = None,
*args,
**kwargs,
):
"""Initializes KerasModel.
Args:
config: Model configuration as dictionary, YAML or JSON formatted string or file path. This overrides
all other arguments.
*args: Accepts a single argument which is the model configuration as dictionary, YAML or JSON
formatted string or file path.
**kwargs: See class attributes.
"""
super().__init__(config, **kwargs)
super().__init__(*args, **kwargs)

@validator("model", pre=True)
@field_validator("model", mode="before")
def validate_keras_model(cls, v):
if isinstance(v, (str, os.PathLike)):
if os.path.exists(v):
v = keras.models.load_model(v)
else:
raise ValueError(f"Path {v} does not exist!")
return v

@validator("output_format")
@field_validator("output_format")
def validate_output_format(cls, v):
supported_formats = ["array", "variable", "raw"]
if v not in supported_formats:
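The decorator changes above follow the standard pydantic v1 to v2 migration: validator("model", pre=True) becomes field_validator("model", mode="before"), and the plain validator("output_format") becomes field_validator("output_format"). A standalone sketch of the same output_format check outside lume-model (the Demo class and its error message are illustrative only):

from pydantic import BaseModel, field_validator


class Demo(BaseModel):
    output_format: str = "array"

    @field_validator("output_format")
    def validate_output_format(cls, v):
        # Same check as KerasModel above: restrict the field to the supported output formats.
        supported_formats = ["array", "variable", "raw"]
        if v not in supported_formats:
            raise ValueError(f"Unknown output format {v}, expected one of {supported_formats}.")
        return v


print(Demo(output_format="variable"))  # passes validation; an unsupported value raises a ValidationError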
29 changes: 11 additions & 18 deletions lume_model/torch/model.py → lume_model/models/torch_model.py
@@ -4,17 +4,16 @@
from copy import deepcopy

import torch
import yaml
from pydantic import validator, field_validator
from pydantic import field_validator
from botorch.models.transforms.input import ReversibleInputTransform

from lume_model.base import LUMEBaseModel
from lume_model.variables import (
InputVariable,
OutputVariable,
ScalarInputVariable,
ScalarOutputVariable,
ImageOutputVariable,
# ScalarOutputVariable,
# ImageOutputVariable,
)

logger = logging.getLogger(__name__)
@@ -45,20 +44,14 @@ class TorchModel(LUMEBaseModel):
fixed_model: bool = True

def __init__(self, *args, **kwargs):
"""Initializes TorchModel.
Args:
*args: Accepts a single argument which is the model configuration as dictionary, YAML or JSON
formatted string or file path.
**kwargs: See class attributes.
"""
Initialize Xopt.
"""
if len(args) == 1:
if len(kwargs) > 0:
raise ValueError("cannot specify yaml string and kwargs for Xopt init")
super().__init__(**yaml.safe_load(args[0]))
elif len(args) > 1:
raise ValueError(
"arguments to Xopt must be either a single yaml string "
"or a keyword arguments passed directly to pydantic"
)
else:
super().__init__(**kwargs)
super().__init__(*args, **kwargs)

# set precision
self.model.to(dtype=self.dtype)
@@ -90,7 +83,7 @@ def validate_torch_model(cls, v):
if os.path.exists(v):
v = torch.load(v)
else:
raise ValueError(f"path {v} does not exist!!")
raise ValueError(f"Path {v} does not exist!")
return v

@field_validator("input_transformers", "output_transformers", mode="before")
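TorchModel now leaves construction entirely to LUMEBaseModel and keeps only its field validators, for example accepting either an in-memory torch module or a path to a saved one for the model field. A standalone sketch of that load-from-path validator pattern (the Demo class is illustrative; torch must be installed):

import os

import torch
from pydantic import BaseModel, ConfigDict, field_validator


class Demo(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    model: torch.nn.Module

    @field_validator("model", mode="before")
    def validate_torch_model(cls, v):
        # Strings and path-like values are loaded from disk; module instances pass through unchanged.
        if isinstance(v, (str, os.PathLike)):
            if os.path.exists(v):
                v = torch.load(v)
            else:
                raise ValueError(f"Path {v} does not exist!")
        return v


print(Demo(model=torch.nn.Linear(2, 1)))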