diff --git a/tests/conftest.py b/tests/conftest.py index 90788a045..f4cc2dfcb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,11 +15,11 @@ import fmu.dataio as dio from fmu.config import utilities as ut -from fmu.dataio._model import fields, global_configuration +from fmu.dataio._model import Root, fields, global_configuration from fmu.dataio.dataio import ExportData, read_metadata from fmu.dataio.providers._fmu import FmuEnv -from .utils import _metadata_examples +from .utils import _get_nested_pydantic_models, _metadata_examples logger = logging.getLogger(__name__) @@ -699,3 +699,9 @@ def fixture_drogon_volumes(rootpath): rootpath / "tests/data/drogon/tabular/geogrid--vol.csv", ) ) + + +@pytest.fixture(scope="session") +def pydantic_models_from_root(): + """Return all nested pydantic models from Root and downwards""" + return _get_nested_pydantic_models(Root) diff --git a/tests/test_schema/test_pydantic_logic.py b/tests/test_schema/test_pydantic_logic.py index d7fc7fff6..85a40e2a4 100644 --- a/tests/test_schema/test_pydantic_logic.py +++ b/tests/test_schema/test_pydantic_logic.py @@ -2,13 +2,14 @@ import logging from copy import deepcopy +from typing import Literal, get_args, get_origin import pytest from pydantic import ValidationError -from fmu.dataio._model import Root, data +from fmu.dataio._model import Root, data, enums -from ..utils import _metadata_examples +from ..utils import _get_pydantic_models_from_annotation, _metadata_examples # pylint: disable=no-member @@ -32,6 +33,56 @@ def test_validate(file, example): Root.model_validate(example) +def test_for_optional_fields_without_default(pydantic_models_from_root): + """Test that all optional fields have a default value""" + optionals_without_default = [] + for model in pydantic_models_from_root: + for field_name, field_info in model.model_fields.items(): + if ( + type(None) in get_args(field_info.annotation) + and field_info.is_required() + ): + optionals_without_default.append( + f"{model.__module__}.{model.__name__}.{field_name}" + ) + + assert not optionals_without_default + + +def test_all_content_enums_in_anydata(): + """Test that all content enums are represented with a model in AnyData""" + anydata_models = _get_pydantic_models_from_annotation( + data.AnyData.model_fields["root"].annotation + ) + + content_enums_in_anydata = [] + for model in anydata_models: + # content is used as discriminator in AnyData and + # should be present for all models + assert "content" in model.model_fields + content_annotation = model.model_fields["content"].annotation + + # check that the annotation is a Literal + assert get_origin(content_annotation) == Literal + + # get_args will unpack the enum from the Literal + # into a tuple, should only be one Literal value + assert len(get_args(content_annotation)) == 1 + + # the literal value should be an enum + content_enum = get_args(content_annotation)[0] + assert isinstance(content_enum, enums.Content) + + content_enums_in_anydata.append(content_enum) + + # finally check that all content enums are represented + for content_enum in enums.Content: + assert content_enum in content_enums_in_anydata + + # and that number of models in AnyData matches number of content enums + assert len(anydata_models) == len(enums.Content) + + def test_schema_file_block(metadata_examples): """Test variations on the file block.""" diff --git a/tests/test_units/test_utils.py b/tests/test_units/test_utils.py index 1b24ef9cd..9e9c70b06 100644 --- a/tests/test_units/test_utils.py +++ b/tests/test_units/test_utils.py @@ -3,14 +3,16 @@ import os from pathlib import Path from tempfile import NamedTemporaryFile +from typing import Dict, List, Optional, Union import numpy as np import pytest from xtgeo import Grid, Polygons, RegularSurface from fmu.dataio import _utils as utils +from fmu.dataio._model import fields -from ..utils import inside_rms +from ..utils import _get_pydantic_models_from_annotation, inside_rms @pytest.mark.parametrize( @@ -148,3 +150,29 @@ def test_read_named_envvar(): os.environ["MYTESTENV"] = "mytestvalue" assert utils.read_named_envvar("MYTESTENV") == "mytestvalue" + + +def test_get_pydantic_models_from_annotation(): + annotation = Union[List[fields.Access], fields.File] + assert _get_pydantic_models_from_annotation(annotation) == [ + fields.Access, + fields.File, + ] + annotation = Optional[Union[Dict[str, fields.Access], List[fields.File]]] + assert _get_pydantic_models_from_annotation(annotation) == [ + fields.Access, + fields.File, + ] + + annotation = List[Union[fields.Access, fields.File, fields.Tracklog]] + assert _get_pydantic_models_from_annotation(annotation) == [ + fields.Access, + fields.File, + fields.Tracklog, + ] + + annotation = List[List[List[List[fields.Tracklog]]]] + assert _get_pydantic_models_from_annotation(annotation) == [fields.Tracklog] + + annotation = Union[str, List[int], Dict[str, int]] + assert not _get_pydantic_models_from_annotation(annotation) diff --git a/tests/utils.py b/tests/utils.py index 42a490670..092371319 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,9 +1,13 @@ +from __future__ import annotations + import datetime from functools import wraps from pathlib import Path +from typing import Any, get_args import pytest import yaml +from pydantic import BaseModel def inside_rms(func): @@ -43,3 +47,28 @@ def _metadata_examples(): path.name: _isoformat_all_datetimes(_parse_yaml(path)) for path in Path(".").absolute().glob("schema/definitions/0.8.0/examples/*.yml") } + + +def _get_pydantic_models_from_annotation(annotation: Any) -> list[Any]: + """ + Get a list of all pydantic models defined inside an annotation. + Example: Union[Model1, list[dict[str, Model2]]] returns [Model1, Model2] + """ + if isinstance(annotation, type(BaseModel)): + return [annotation] + + annotations = [] + for ann in get_args(annotation): + annotations += _get_pydantic_models_from_annotation(ann) + return annotations + + +def _get_nested_pydantic_models(model: type[BaseModel]) -> set[type[BaseModel]]: + """Get a set of all nested pydantic models from a pydantic model""" + models = {model} + + for field_info in model.model_fields.values(): + for model in _get_pydantic_models_from_annotation(field_info.annotation): + if model not in models: + models.update(_get_nested_pydantic_models(model)) + return models