diff --git a/src/fmu/dataio/_metadata.py b/src/fmu/dataio/_metadata.py index 942891c04..806feb013 100644 --- a/src/fmu/dataio/_metadata.py +++ b/src/fmu/dataio/_metadata.py @@ -32,7 +32,7 @@ from .datastructure.meta import meta from .providers._filedata import FileDataProvider from .providers._fmu import FmuProvider -from .providers._objectdata import objectdata_provider_factory +from .providers.objectdata._provider import objectdata_provider_factory from .version import __version__ if TYPE_CHECKING: diff --git a/src/fmu/dataio/providers/_filedata.py b/src/fmu/dataio/providers/_filedata.py index f9e58ab0f..3bfbcbd20 100644 --- a/src/fmu/dataio/providers/_filedata.py +++ b/src/fmu/dataio/providers/_filedata.py @@ -18,8 +18,8 @@ logger: Final = null_logger(__name__) if TYPE_CHECKING: - from ._objectdata_provider import ObjectDataProvider from .dataio import ExportData + from .objectdata._provider import ObjectDataProvider @dataclass @@ -177,8 +177,6 @@ def _get_path_generic( info: str = "", ) -> Path: """Generically construct and get the folder path and verify.""" - dest = None - outroot = deepcopy(self.rootpath) logger.info("FMU context is %s", mode) diff --git a/src/fmu/dataio/providers/_objectdata_base.py b/src/fmu/dataio/providers/objectdata/_base.py similarity index 100% rename from src/fmu/dataio/providers/_objectdata_base.py rename to src/fmu/dataio/providers/objectdata/_base.py diff --git a/src/fmu/dataio/providers/_objectdata.py b/src/fmu/dataio/providers/objectdata/_provider.py similarity index 58% rename from src/fmu/dataio/providers/_objectdata.py rename to src/fmu/dataio/providers/objectdata/_provider.py index f5686f19a..201b83e3b 100644 --- a/src/fmu/dataio/providers/_objectdata.py +++ b/src/fmu/dataio/providers/objectdata/_provider.py @@ -91,16 +91,15 @@ import pandas as pd import xtgeo - -from fmu.dataio._definitions import STANDARD_TABLE_INDEX_COLUMNS, ValidFormats +from fmu.dataio._definitions import ValidFormats from fmu.dataio._logging import null_logger -from fmu.dataio.datastructure.meta import specification -from ._objectdata_base import ( +from ._base import ( DerivedObjectDescriptor, ObjectDataProvider, ) -from ._objectdata_xtgeo import ( +from ._tables import ArrowTableDataProvider, DataFrameDataProvider +from ._xtgeo import ( CPGridDataProvider, CPGridPropertyDataProvider, CubeDataProvider, @@ -110,8 +109,6 @@ ) if TYPE_CHECKING: - import pyarrow - from fmu.dataio.dataio import ExportData from fmu.dataio.types import Inferrable @@ -160,76 +157,38 @@ def objectdata_provider_factory( @dataclass -class DataFrameDataProvider(ObjectDataProvider): - obj: pd.DataFrame - - def _check_index(self, index: list[str]) -> None: - """Check the table index. - Args: - index (list): list of column names - - Raises: - KeyError: if index contains names that are not in self - """ - - not_founds = (item for item in index if item not in list(self.obj.columns)) - for not_found in not_founds: - raise KeyError(f"{not_found} is not in table") - - def _derive_index(self) -> list[str]: - """Derive table index""" - # This could in the future also return context - columns = list(self.obj.columns) - index = [] - - if self.dataio.table_index is None: - logger.debug("Finding index to include") - for context, standard_cols in STANDARD_TABLE_INDEX_COLUMNS.items(): - for valid_col in standard_cols: - if valid_col in columns: - index.append(valid_col) - if index: - logger.info("Context is %s ", context) - logger.debug("Proudly presenting the index: %s", index) - else: - index = self.dataio.table_index - - if "REAL" in columns: - index.append("REAL") - self._check_index(index) - return index +class ExistingDataProvider(ObjectDataProvider): + """These getters should never be called because metadata was derived a priori.""" - def get_spec(self) -> dict[str, Any]: - """Derive data.spec for pd.DataFrame.""" - logger.info("Get spec for pd.DataFrame (tables)") - - return specification.TableSpecification( - columns=list(self.obj.columns), - size=int(self.obj.size), - ).model_dump( - mode="json", - exclude_none=True, - ) + obj: Inferrable - def get_bbox(self) -> dict[str, Any]: - """Derive data.bbox for pd.DataFrame.""" - logger.info("Get bbox for pd.DataFrame (tables)") - return {} + def get_spec(self) -> dict: + """Derive data.spec from existing metadata.""" + return self.metadata["spec"] + + def get_bbox(self) -> dict: + """Derive data.bbox from existing metadata.""" + return self.metadata["bbox"] def get_objectdata(self) -> DerivedObjectDescriptor: - """Derive object data for pd.DataFrame.""" + """Derive object data for existing metadata.""" return DerivedObjectDescriptor( - subtype="DataFrame", - classname="table", - layout="table", - efolder="tables", - fmt=(fmt := self.dataio.table_fformat), - extension=self._validate_get_ext(fmt, "DataFrame", ValidFormats().table), + subtype=self.metadata["subtype"], + classname=self.metadata["class"], + layout=self.metadata["layout"], + efolder=self.efolder, + fmt=self.fmt, + extension=self.extension, spec=self.get_spec(), bbox=self.get_bbox(), - table_index=self._derive_index(), + table_index=None, ) + def derive_metadata(self) -> None: + """Metadata has already been derived for this provider, and is already set from + instantiation, so override this method and do nothing.""" + return + @dataclass class DictionaryDataProvider(ObjectDataProvider): @@ -258,108 +217,3 @@ def get_objectdata(self) -> DerivedObjectDescriptor: bbox=self.get_bbox(), table_index=None, ) - - -class ArrowTableDataProvider(ObjectDataProvider): - obj: pyarrow.Table - - def _check_index(self, index: list[str]) -> None: - """Check the table index. - Args: - index (list): list of column names - - Raises: - KeyError: if index contains names that are not in self - """ - - not_founds = (item for item in index if item not in self.obj.column_names) - for not_found in not_founds: - raise KeyError(f"{not_found} is not in table") - - def _derive_index(self) -> list[str]: - """Derive table index""" - # This could in the future also return context - columns = self.obj.column_names - index = [] - - if self.dataio.table_index is None: - logger.debug("Finding index to include") - for context, standard_cols in STANDARD_TABLE_INDEX_COLUMNS.items(): - for valid_col in standard_cols: - if valid_col in columns: - index.append(valid_col) - if index: - logger.info("Context is %s ", context) - logger.debug("Proudly presenting the index: %s", index) - else: - index = self.dataio.table_index - - if "REAL" in columns: - index.append("REAL") - self._check_index(index) - return index - - def get_spec(self) -> dict[str, Any]: - """Derive data.spec for pyarrow.Table.""" - logger.info("Get spec for pyarrow (tables)") - - return specification.TableSpecification( - columns=list(self.obj.column_names), - size=self.obj.num_columns * self.obj.num_rows, - ).model_dump( - mode="json", - exclude_none=True, - ) - - def get_bbox(self) -> dict[str, Any]: - """Derive data.bbox for pyarrow.Table.""" - logger.info("Get bbox for pyarrow (tables)") - return {} - - def get_objectdata(self) -> DerivedObjectDescriptor: - """Derive object data from pyarrow.Table.""" - return DerivedObjectDescriptor( - table_index=self._derive_index(), - subtype="ArrowTable", - classname="table", - layout="table", - efolder="tables", - fmt=(fmt := self.dataio.arrow_fformat), - extension=self._validate_get_ext(fmt, "ArrowTable", ValidFormats().table), - spec=self.get_spec(), - bbox=self.get_bbox(), - ) - - -@dataclass -class ExistingDataProvider(ObjectDataProvider): - """These getters should never be called because metadata was derived a priori.""" - - obj: Inferrable - - def get_spec(self) -> dict: - """Derive data.spec from existing metadata.""" - return self.metadata["spec"] - - def get_bbox(self) -> dict: - """Derive data.bbox from existing metadata.""" - return self.metadata["bbox"] - - def get_objectdata(self) -> DerivedObjectDescriptor: - """Derive object data for existing metadata.""" - return DerivedObjectDescriptor( - subtype=self.metadata["subtype"], - classname=self.metadata["class"], - layout=self.metadata["layout"], - efolder=self.efolder, - fmt=self.fmt, - extension=self.extension, - spec=self.get_spec(), - bbox=self.get_bbox(), - table_index=None, - ) - - def derive_metadata(self) -> None: - """Metadata has already been derived for this provider, and is already set from - instantiation, so override this method and do nothing.""" - return diff --git a/src/fmu/dataio/providers/objectdata/_tables.py b/src/fmu/dataio/providers/objectdata/_tables.py new file mode 100644 index 000000000..28d9edfaf --- /dev/null +++ b/src/fmu/dataio/providers/objectdata/_tables.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Final + +import pandas as pd +from fmu.dataio._definitions import STANDARD_TABLE_INDEX_COLUMNS, ValidFormats +from fmu.dataio._logging import null_logger +from fmu.dataio.datastructure.meta import specification + +from ._base import ( + DerivedObjectDescriptor, + ObjectDataProvider, +) + +if TYPE_CHECKING: + import pyarrow + +logger: Final = null_logger(__name__) + + +def _check_index_in_columns(index: list[str], columns: list[str]) -> None: + """Check the table index. + Args: + index (list): list of column names + + Raises: + KeyError: if index contains names that are not in self + """ + + not_founds = (item for item in index if item not in columns) + for not_found in not_founds: + raise KeyError(f"{not_found} is not in table") + + +def _derive_index(table_index: list[str] | None, columns: list[str]) -> list[str]: + index = [] + + if table_index is None: + logger.debug("Finding index to include") + for context, standard_cols in STANDARD_TABLE_INDEX_COLUMNS.items(): + for valid_col in standard_cols: + if valid_col in columns: + index.append(valid_col) + if index: + logger.info("Context is %s ", context) + logger.debug("Proudly presenting the index: %s", index) + else: + index = table_index + + if "REAL" in columns: + index.append("REAL") + _check_index_in_columns(index, columns) + return index + + +@dataclass +class DataFrameDataProvider(ObjectDataProvider): + obj: pd.DataFrame + + def get_spec(self) -> dict: + """Derive data.spec for pd.DataFrame.""" + logger.info("Get spec for pd.DataFrame (tables)") + + return specification.TableSpecification( + columns=list(self.obj.columns), + size=int(self.obj.size), + ).model_dump( + mode="json", + exclude_none=True, + ) + + def get_bbox(self) -> dict: + """Derive data.bbox for pd.DataFrame.""" + logger.info("Get bbox for pd.DataFrame (tables)") + return {} + + def get_objectdata(self) -> DerivedObjectDescriptor: + """Derive object data for pd.DataFrame.""" + table_index = _derive_index(self.dataio.table_index, list(self.obj.columns)) + return DerivedObjectDescriptor( + subtype="DataFrame", + classname="table", + layout="table", + efolder="tables", + fmt=(fmt := self.dataio.table_fformat), + extension=self._validate_get_ext(fmt, "DataFrame", ValidFormats().table), + spec=self.get_spec(), + bbox=self.get_bbox(), + table_index=table_index, + ) + + +@dataclass +class ArrowTableDataProvider(ObjectDataProvider): + obj: pyarrow.Table + + def get_spec(self) -> dict: + """Derive data.spec for pyarrow.Table.""" + logger.info("Get spec for pyarrow (tables)") + + return specification.TableSpecification( + columns=list(self.obj.column_names), + size=self.obj.num_columns * self.obj.num_rows, + ).model_dump( + mode="json", + exclude_none=True, + ) + + def get_bbox(self) -> dict: + """Derive data.bbox for pyarrow.Table.""" + logger.info("Get bbox for pyarrow (tables)") + return {} + + def get_objectdata(self) -> DerivedObjectDescriptor: + """Derive object data from pyarrow.Table.""" + table_index = _derive_index(self.dataio.table_index, self.obj.column_names) + return DerivedObjectDescriptor( + subtype="ArrowTable", + classname="table", + layout="table", + efolder="tables", + fmt=(fmt := self.dataio.arrow_fformat), + extension=self._validate_get_ext(fmt, "ArrowTable", ValidFormats().table), + spec=self.get_spec(), + bbox=self.get_bbox(), + table_index=table_index, + ) diff --git a/src/fmu/dataio/providers/_objectdata_xtgeo.py b/src/fmu/dataio/providers/objectdata/_xtgeo.py similarity index 99% rename from src/fmu/dataio/providers/_objectdata_xtgeo.py rename to src/fmu/dataio/providers/objectdata/_xtgeo.py index 8d156e993..8d6c67cd7 100644 --- a/src/fmu/dataio/providers/_objectdata_xtgeo.py +++ b/src/fmu/dataio/providers/objectdata/_xtgeo.py @@ -6,13 +6,12 @@ import numpy as np import pandas as pd import xtgeo - from fmu.dataio._definitions import ValidFormats from fmu.dataio._logging import null_logger from fmu.dataio._utils import npfloat_to_float from fmu.dataio.datastructure.meta import meta, specification -from ._objectdata_base import ( +from ._base import ( DerivedObjectDescriptor, ObjectDataProvider, ) diff --git a/tests/test_units/test_filedataprovider_class.py b/tests/test_units/test_filedataprovider_class.py index 13735863e..33a7b71a6 100644 --- a/tests/test_units/test_filedataprovider_class.py +++ b/tests/test_units/test_filedataprovider_class.py @@ -6,8 +6,8 @@ import pytest from fmu.dataio import ExportData from fmu.dataio.providers._filedata import FileDataProvider -from fmu.dataio.providers._objectdata import objectdata_provider_factory -from fmu.dataio.providers._objectdata_base import derive_name +from fmu.dataio.providers.objectdata._base import derive_name +from fmu.dataio.providers.objectdata._provider import objectdata_provider_factory from xtgeo.cube import Cube from xtgeo.surface import RegularSurface diff --git a/tests/test_units/test_objectdataprovider_class.py b/tests/test_units/test_objectdataprovider_class.py index 932f92cfa..032cbf9fc 100644 --- a/tests/test_units/test_objectdataprovider_class.py +++ b/tests/test_units/test_objectdataprovider_class.py @@ -6,11 +6,11 @@ from fmu.dataio import dataio from fmu.dataio._definitions import ConfigurationError, ValidFormats from fmu.dataio._metadata import MetaData -from fmu.dataio.providers._objectdata import ( +from fmu.dataio.providers.objectdata._provider import ( ExistingDataProvider, objectdata_provider_factory, ) -from fmu.dataio.providers._objectdata_xtgeo import RegularSurfaceDataProvider +from fmu.dataio.providers.objectdata._xtgeo import RegularSurfaceDataProvider from ..utils import inside_rms diff --git a/tests/test_units/test_table.py b/tests/test_units/test_table.py index 1c646c739..12c923a99 100644 --- a/tests/test_units/test_table.py +++ b/tests/test_units/test_table.py @@ -5,7 +5,7 @@ import pytest from fmu.config.utilities import yaml_load from fmu.dataio import ExportData -from fmu.dataio.providers._objectdata import objectdata_provider_factory +from fmu.dataio.providers.objectdata._provider import objectdata_provider_factory def _read_dict(file_path):