CLN: Further split up ObjectData providers
This structure makes adding new object types simpler and keeps the providers
more maintainable and scalable.
mferrera committed Mar 25, 2024
1 parent 28aad46 commit 8349faa
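To illustrate the stated benefit, here is a minimal sketch of what adding a new object type could look like under the new providers/objectdata/ package layout. The module name, class name, and all literal values are hypothetical and not part of this commit; the sketch simply mirrors the provider pattern visible in the diff below.

# Hypothetical sketch only: a new provider module, e.g.
# src/fmu/dataio/providers/objectdata/_newtype.py. NewTypeDataProvider and
# all literal values are invented for illustration; they are not in this commit.
from __future__ import annotations

from dataclasses import dataclass
from typing import Any

from ._base import DerivedObjectDescriptor, ObjectDataProvider


@dataclass
class NewTypeDataProvider(ObjectDataProvider):
    obj: Any  # the new object type this provider describes

    def get_spec(self) -> dict:
        # Derive data.spec for the new type; empty placeholder here.
        return {}

    def get_bbox(self) -> dict:
        # Derive data.bbox for the new type; empty placeholder here.
        return {}

    def get_objectdata(self) -> DerivedObjectDescriptor:
        # Assemble the descriptor; every literal below is a placeholder.
        return DerivedObjectDescriptor(
            subtype="NewType",
            classname="newtype",
            layout="table",
            efolder="newtypes",
            fmt="csv",
            extension=".csv",
            spec=self.get_spec(),
            bbox=self.get_bbox(),
            table_index=None,
        )

Registering the type would then only require importing the new provider in _provider.py and mapping the object type to it in objectdata_provider_factory().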
Showing 9 changed files with 163 additions and 184 deletions.
2 changes: 1 addition & 1 deletion src/fmu/dataio/_metadata.py
@@ -32,7 +32,7 @@
from .datastructure.meta import meta
from .providers._filedata import FileDataProvider
from .providers._fmu import FmuProvider
from .providers._objectdata import objectdata_provider_factory
from .providers.objectdata._provider import objectdata_provider_factory
from .version import __version__

if TYPE_CHECKING:
4 changes: 1 addition & 3 deletions src/fmu/dataio/providers/_filedata.py
@@ -18,8 +18,8 @@
logger: Final = null_logger(__name__)

if TYPE_CHECKING:
from ._objectdata_provider import ObjectDataProvider
from .dataio import ExportData
from .objectdata._provider import ObjectDataProvider


@dataclass
@@ -177,8 +177,6 @@ def _get_path_generic(
info: str = "",
) -> Path:
"""Generically construct and get the folder path and verify."""
dest = None

outroot = deepcopy(self.rootpath)

logger.info("FMU context is %s", mode)
@@ -91,16 +91,15 @@

import pandas as pd
import xtgeo

from fmu.dataio._definitions import STANDARD_TABLE_INDEX_COLUMNS, ValidFormats
from fmu.dataio._definitions import ValidFormats
from fmu.dataio._logging import null_logger
from fmu.dataio.datastructure.meta import specification

from ._objectdata_base import (
from ._base import (
DerivedObjectDescriptor,
ObjectDataProvider,
)
from ._objectdata_xtgeo import (
from ._tables import ArrowTableDataProvider, DataFrameDataProvider
from ._xtgeo import (
CPGridDataProvider,
CPGridPropertyDataProvider,
CubeDataProvider,
@@ -110,8 +109,6 @@
)

if TYPE_CHECKING:
import pyarrow

from fmu.dataio.dataio import ExportData
from fmu.dataio.types import Inferrable

@@ -160,76 +157,38 @@ def objectdata_provider_factory(


@dataclass
class DataFrameDataProvider(ObjectDataProvider):
obj: pd.DataFrame

def _check_index(self, index: list[str]) -> None:
"""Check the table index.
Args:
index (list): list of column names
Raises:
KeyError: if index contains names that are not in self
"""

not_founds = (item for item in index if item not in list(self.obj.columns))
for not_found in not_founds:
raise KeyError(f"{not_found} is not in table")

def _derive_index(self) -> list[str]:
"""Derive table index"""
# This could in the future also return context
columns = list(self.obj.columns)
index = []

if self.dataio.table_index is None:
logger.debug("Finding index to include")
for context, standard_cols in STANDARD_TABLE_INDEX_COLUMNS.items():
for valid_col in standard_cols:
if valid_col in columns:
index.append(valid_col)
if index:
logger.info("Context is %s ", context)
logger.debug("Proudly presenting the index: %s", index)
else:
index = self.dataio.table_index

if "REAL" in columns:
index.append("REAL")
self._check_index(index)
return index
class ExistingDataProvider(ObjectDataProvider):
"""These getters should never be called because metadata was derived a priori."""

def get_spec(self) -> dict[str, Any]:
"""Derive data.spec for pd.DataFrame."""
logger.info("Get spec for pd.DataFrame (tables)")

return specification.TableSpecification(
columns=list(self.obj.columns),
size=int(self.obj.size),
).model_dump(
mode="json",
exclude_none=True,
)
obj: Inferrable

def get_bbox(self) -> dict[str, Any]:
"""Derive data.bbox for pd.DataFrame."""
logger.info("Get bbox for pd.DataFrame (tables)")
return {}
def get_spec(self) -> dict:
"""Derive data.spec from existing metadata."""
return self.metadata["spec"]

def get_bbox(self) -> dict:
"""Derive data.bbox from existing metadata."""
return self.metadata["bbox"]

def get_objectdata(self) -> DerivedObjectDescriptor:
"""Derive object data for pd.DataFrame."""
"""Derive object data for existing metadata."""
return DerivedObjectDescriptor(
subtype="DataFrame",
classname="table",
layout="table",
efolder="tables",
fmt=(fmt := self.dataio.table_fformat),
extension=self._validate_get_ext(fmt, "DataFrame", ValidFormats().table),
subtype=self.metadata["subtype"],
classname=self.metadata["class"],
layout=self.metadata["layout"],
efolder=self.efolder,
fmt=self.fmt,
extension=self.extension,
spec=self.get_spec(),
bbox=self.get_bbox(),
table_index=self._derive_index(),
table_index=None,
)

def derive_metadata(self) -> None:
"""Metadata has already been derived for this provider, and is already set from
instantiation, so override this method and do nothing."""
return


@dataclass
class DictionaryDataProvider(ObjectDataProvider):
@@ -258,108 +217,3 @@ def get_objectdata(self) -> DerivedObjectDescriptor:
bbox=self.get_bbox(),
table_index=None,
)


class ArrowTableDataProvider(ObjectDataProvider):
obj: pyarrow.Table

def _check_index(self, index: list[str]) -> None:
"""Check the table index.
Args:
index (list): list of column names
Raises:
KeyError: if index contains names that are not in self
"""

not_founds = (item for item in index if item not in self.obj.column_names)
for not_found in not_founds:
raise KeyError(f"{not_found} is not in table")

def _derive_index(self) -> list[str]:
"""Derive table index"""
# This could in the future also return context
columns = self.obj.column_names
index = []

if self.dataio.table_index is None:
logger.debug("Finding index to include")
for context, standard_cols in STANDARD_TABLE_INDEX_COLUMNS.items():
for valid_col in standard_cols:
if valid_col in columns:
index.append(valid_col)
if index:
logger.info("Context is %s ", context)
logger.debug("Proudly presenting the index: %s", index)
else:
index = self.dataio.table_index

if "REAL" in columns:
index.append("REAL")
self._check_index(index)
return index

def get_spec(self) -> dict[str, Any]:
"""Derive data.spec for pyarrow.Table."""
logger.info("Get spec for pyarrow (tables)")

return specification.TableSpecification(
columns=list(self.obj.column_names),
size=self.obj.num_columns * self.obj.num_rows,
).model_dump(
mode="json",
exclude_none=True,
)

def get_bbox(self) -> dict[str, Any]:
"""Derive data.bbox for pyarrow.Table."""
logger.info("Get bbox for pyarrow (tables)")
return {}

def get_objectdata(self) -> DerivedObjectDescriptor:
"""Derive object data from pyarrow.Table."""
return DerivedObjectDescriptor(
table_index=self._derive_index(),
subtype="ArrowTable",
classname="table",
layout="table",
efolder="tables",
fmt=(fmt := self.dataio.arrow_fformat),
extension=self._validate_get_ext(fmt, "ArrowTable", ValidFormats().table),
spec=self.get_spec(),
bbox=self.get_bbox(),
)


@dataclass
class ExistingDataProvider(ObjectDataProvider):
"""These getters should never be called because metadata was derived a priori."""

obj: Inferrable

def get_spec(self) -> dict:
"""Derive data.spec from existing metadata."""
return self.metadata["spec"]

def get_bbox(self) -> dict:
"""Derive data.bbox from existing metadata."""
return self.metadata["bbox"]

def get_objectdata(self) -> DerivedObjectDescriptor:
"""Derive object data for existing metadata."""
return DerivedObjectDescriptor(
subtype=self.metadata["subtype"],
classname=self.metadata["class"],
layout=self.metadata["layout"],
efolder=self.efolder,
fmt=self.fmt,
extension=self.extension,
spec=self.get_spec(),
bbox=self.get_bbox(),
table_index=None,
)

def derive_metadata(self) -> None:
"""Metadata has already been derived for this provider, and is already set from
instantiation, so override this method and do nothing."""
return
128 changes: 128 additions & 0 deletions src/fmu/dataio/providers/objectdata/_tables.py
@@ -0,0 +1,128 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Final

import pandas as pd
from fmu.dataio._definitions import STANDARD_TABLE_INDEX_COLUMNS, ValidFormats
from fmu.dataio._logging import null_logger
from fmu.dataio.datastructure.meta import specification

from ._base import (
DerivedObjectDescriptor,
ObjectDataProvider,
)

if TYPE_CHECKING:
import pyarrow

logger: Final = null_logger(__name__)


def _check_index_in_columns(index: list[str], columns: list[str]) -> None:
"""Check the table index.
Args:
index (list): list of column names
Raises:
KeyError: if index contains names that are not in self
"""

not_founds = (item for item in index if item not in columns)
for not_found in not_founds:
raise KeyError(f"{not_found} is not in table")


def _derive_index(table_index: list[str] | None, columns: list[str]) -> list[str]:
index = []

if table_index is None:
logger.debug("Finding index to include")
for context, standard_cols in STANDARD_TABLE_INDEX_COLUMNS.items():
for valid_col in standard_cols:
if valid_col in columns:
index.append(valid_col)
if index:
logger.info("Context is %s ", context)
logger.debug("Proudly presenting the index: %s", index)
else:
index = table_index

if "REAL" in columns:
index.append("REAL")
_check_index_in_columns(index, columns)
return index


@dataclass
class DataFrameDataProvider(ObjectDataProvider):
obj: pd.DataFrame

def get_spec(self) -> dict:
"""Derive data.spec for pd.DataFrame."""
logger.info("Get spec for pd.DataFrame (tables)")

return specification.TableSpecification(
columns=list(self.obj.columns),
size=int(self.obj.size),
).model_dump(
mode="json",
exclude_none=True,
)

def get_bbox(self) -> dict:
"""Derive data.bbox for pd.DataFrame."""
logger.info("Get bbox for pd.DataFrame (tables)")
return {}

def get_objectdata(self) -> DerivedObjectDescriptor:
"""Derive object data for pd.DataFrame."""
table_index = _derive_index(self.dataio.table_index, list(self.obj.columns))
return DerivedObjectDescriptor(
subtype="DataFrame",
classname="table",
layout="table",
efolder="tables",
fmt=(fmt := self.dataio.table_fformat),
extension=self._validate_get_ext(fmt, "DataFrame", ValidFormats().table),
spec=self.get_spec(),
bbox=self.get_bbox(),
table_index=table_index,
)


@dataclass
class ArrowTableDataProvider(ObjectDataProvider):
obj: pyarrow.Table

def get_spec(self) -> dict:
"""Derive data.spec for pyarrow.Table."""
logger.info("Get spec for pyarrow (tables)")

return specification.TableSpecification(
columns=list(self.obj.column_names),
size=self.obj.num_columns * self.obj.num_rows,
).model_dump(
mode="json",
exclude_none=True,
)

def get_bbox(self) -> dict:
"""Derive data.bbox for pyarrow.Table."""
logger.info("Get bbox for pyarrow (tables)")
return {}

def get_objectdata(self) -> DerivedObjectDescriptor:
"""Derive object data from pyarrow.Table."""
table_index = _derive_index(self.dataio.table_index, self.obj.column_names)
return DerivedObjectDescriptor(
subtype="ArrowTable",
classname="table",
layout="table",
efolder="tables",
fmt=(fmt := self.dataio.arrow_fformat),
extension=self._validate_get_ext(fmt, "ArrowTable", ValidFormats().table),
spec=self.get_spec(),
bbox=self.get_bbox(),
table_index=table_index,
)
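For reference, a rough usage illustration of the shared _derive_index() helper added above. The column names are invented for the example, and which of them count as standard index columns depends on STANDARD_TABLE_INDEX_COLUMNS.

# Hypothetical usage of _derive_index(); the column names are illustrative only.
columns = ["DATE", "ZONE", "REAL", "PORO"]

# With table_index=None the standard index columns found in the table are
# collected, and "REAL" is appended when present.
index = _derive_index(table_index=None, columns=columns)
# -> e.g. ["DATE", "ZONE", "REAL"], assuming DATE and ZONE appear in
#    STANDARD_TABLE_INDEX_COLUMNS

# With an explicit table_index, that list is used instead ("REAL" is still
# appended), and _check_index_in_columns() raises KeyError if a requested
# index column is missing from the table.
index = _derive_index(table_index=["ZONE"], columns=columns)
# -> ["ZONE", "REAL"]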