Skip to content

Commit

Permalink
TST: Test that all content enums are represented in AnyData
Browse files Browse the repository at this point in the history
  • Loading branch information
tnatt committed Oct 9, 2024
1 parent acb2f25 commit 323e195
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 5 deletions.
10 changes: 8 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@

import fmu.dataio as dio
from fmu.config import utilities as ut
from fmu.dataio._model import fields, global_configuration
from fmu.dataio._model import Root, fields, global_configuration
from fmu.dataio.dataio import ExportData, read_metadata
from fmu.dataio.providers._fmu import FmuEnv

from .utils import _metadata_examples
from .utils import _get_nested_pydantic_models, _metadata_examples

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -699,3 +699,9 @@ def fixture_drogon_volumes(rootpath):
rootpath / "tests/data/drogon/tabular/geogrid--vol.csv",
)
)


@pytest.fixture(scope="session")
def pydantic_models_from_root():
"""Return all nested pydantic models from Root and downwards"""
return _get_nested_pydantic_models(Root)
55 changes: 53 additions & 2 deletions tests/test_schema/test_pydantic_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

import logging
from copy import deepcopy
from typing import Literal, get_args, get_origin

import pytest
from pydantic import ValidationError

from fmu.dataio._model import Root, data
from fmu.dataio._model import Root, data, enums

from ..utils import _metadata_examples
from ..utils import _get_pydantic_models_from_annotation, _metadata_examples

# pylint: disable=no-member

Expand All @@ -32,6 +33,56 @@ def test_validate(file, example):
Root.model_validate(example)


def test_for_optional_fields_without_default(pydantic_models_from_root):
"""Test that all optional fields have a default value"""
optionals_without_default = []
for model in pydantic_models_from_root:
for field_name, field_info in model.model_fields.items():
if (
type(None) in get_args(field_info.annotation)
and field_info.is_required()
):
optionals_without_default.append(
f"{model.__module__}.{model.__name__}.{field_name}"
)

assert not optionals_without_default


def test_all_content_enums_in_anydata():
"""Test that all content enums are represented with a model in AnyData"""
anydata_models = _get_pydantic_models_from_annotation(
data.AnyData.model_fields["root"].annotation
)

content_enums_in_anydata = []
for model in anydata_models:
# content is used as discriminator in AnyData and
# should be present for all models
assert "content" in model.model_fields
content_annotation = model.model_fields["content"].annotation

# check that the annotation is a Literal
assert get_origin(content_annotation) == Literal

# get_args will unpack the enum from the Literal
# into a tuple, should only be one Literal value
assert len(get_args(content_annotation)) == 1

# the literal value should be an enum
content_enum = get_args(content_annotation)[0]
assert isinstance(content_enum, enums.Content)

content_enums_in_anydata.append(content_enum)

# finally check that all content enums are represented
for content_enum in enums.Content:
assert content_enum in content_enums_in_anydata

# and that number of models in AnyData matches number of content enums
assert len(anydata_models) == len(enums.Content)


def test_schema_file_block(metadata_examples):
"""Test variations on the file block."""

Expand Down
30 changes: 29 additions & 1 deletion tests/test_units/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@
import os
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Dict, List, Optional, Union

import numpy as np
import pytest
from xtgeo import Grid, Polygons, RegularSurface

from fmu.dataio import _utils as utils
from fmu.dataio._model import fields

from ..utils import inside_rms
from ..utils import _get_pydantic_models_from_annotation, inside_rms


@pytest.mark.parametrize(
Expand Down Expand Up @@ -148,3 +150,29 @@ def test_read_named_envvar():

os.environ["MYTESTENV"] = "mytestvalue"
assert utils.read_named_envvar("MYTESTENV") == "mytestvalue"


def test_get_pydantic_models_from_annotation():
annotation = Union[List[fields.Access], fields.File]
assert _get_pydantic_models_from_annotation(annotation) == [
fields.Access,
fields.File,
]
annotation = Optional[Union[Dict[str, fields.Access], List[fields.File]]]
assert _get_pydantic_models_from_annotation(annotation) == [
fields.Access,
fields.File,
]

annotation = List[Union[fields.Access, fields.File, fields.Tracklog]]
assert _get_pydantic_models_from_annotation(annotation) == [
fields.Access,
fields.File,
fields.Tracklog,
]

annotation = List[List[List[List[fields.Tracklog]]]]
assert _get_pydantic_models_from_annotation(annotation) == [fields.Tracklog]

annotation = Union[str, List[int], Dict[str, int]]
assert not _get_pydantic_models_from_annotation(annotation)
29 changes: 29 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from __future__ import annotations

import datetime
from functools import wraps
from pathlib import Path
from typing import Any, get_args

import pytest
import yaml
from pydantic import BaseModel


def inside_rms(func):
Expand Down Expand Up @@ -43,3 +47,28 @@ def _metadata_examples():
path.name: _isoformat_all_datetimes(_parse_yaml(path))
for path in Path(".").absolute().glob("schema/definitions/0.8.0/examples/*.yml")
}


def _get_pydantic_models_from_annotation(annotation: Any) -> list[Any]:
"""
Get a list of all pydantic models defined inside an annotation.
Example: Union[Model1, list[dict[str, Model2]]] returns [Model1, Model2]
"""
if isinstance(annotation, type(BaseModel)):
return [annotation]

annotations = []
for ann in get_args(annotation):
annotations += _get_pydantic_models_from_annotation(ann)
return annotations


def _get_nested_pydantic_models(model: type[BaseModel]) -> set[type[BaseModel]]:
"""Get a set of all nested pydantic models from a pydantic model"""
models = {model}

for field_info in model.model_fields.values():
for model in _get_pydantic_models_from_annotation(field_info.annotation):
if model not in models:
models.update(_get_nested_pydantic_models(model))
return models

0 comments on commit 323e195

Please sign in to comment.