Skip to content

Commit

Permalink
CLN: Improve _derive_timedata type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
JB Lovland committed Feb 12, 2024
1 parent 84157ae commit eaa8b72
Showing 1 changed file with 64 additions and 66 deletions.
130 changes: 64 additions & 66 deletions src/fmu/dataio/_objectdata_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
"""
from __future__ import annotations

from dataclasses import dataclass, field
from dataclasses import asdict, dataclass, field
from datetime import datetime as dt
from pathlib import Path
from typing import Any, Dict, Final, NamedTuple, Optional
Expand Down Expand Up @@ -116,6 +116,30 @@ def npfloat_to_float(v: Any) -> Any:
return float(v) if isinstance(v, (np.float64, np.float32)) else v


@dataclass
class TimedataValueLabel:
value: str
label: str = field(default="")

@staticmethod
def from_list(arr: list) -> TimedataValueLabel:
return TimedataValueLabel(
value=dt.strptime(str(arr[0]), "%Y%m%d").isoformat(),
label=arr[1] if len(arr) == 2 else "",
)


@dataclass
class TimedataOldFormat:
time: list[TimedataValueLabel]


@dataclass
class TimedataNewFormat:
t0: Optional[TimedataValueLabel]
t1: Optional[TimedataValueLabel]


@dataclass
class _ObjectDataProvider:
"""Class for providing metadata for data objects in fmu-dataio, e.g. a surface.
Expand Down Expand Up @@ -635,59 +659,45 @@ def _check_index(self, index: list[str]) -> None:
for not_found in not_founds:
raise KeyError(f"{not_found} is not in table")

def _derive_timedata(self) -> dict:
def _derive_timedata(self) -> TimedataNewFormat | TimedataOldFormat:
"""Format input timedata to metadata."""

tdata = self.dataio.timedata
if not tdata:
return {}
return TimedataNewFormat(None, None)

if self.dataio.legacy_time_format:
timedata = self._derive_timedata_legacy()
else:
timedata = self._derive_timedata_newformat()
return timedata
return (
self._derive_timedata_legacy()
if self.dataio.legacy_time_format
else self._derive_timedata_newformat()
)

def _derive_timedata_legacy(self) -> dict[str, Any]:
def _derive_timedata_legacy(self) -> TimedataOldFormat:
"""Format input timedata to metadata. legacy version."""
# TODO(JB): Covnert tresult to TypedDict or Dataclass.
tdata = self.dataio.timedata

tresult: dict[str, Any] = {}
tresult["time"] = []
if len(tdata) == 1:
elem = tdata[0]
tresult["time"] = []
xfield = {"value": dt.strptime(str(elem[0]), "%Y%m%d").isoformat()}
self.time0 = str(elem[0])
if len(elem) == 2:
xfield["label"] = elem[1]
tresult["time"].append(xfield)
start = TimedataValueLabel.from_list(tdata[0])
self.time0 = start.value
return TimedataOldFormat([start])

if len(tdata) == 2:
elem1 = tdata[0]
xfield1 = {"value": dt.strptime(str(elem1[0]), "%Y%m%d").isoformat()}
if len(elem1) == 2:
xfield1["label"] = elem1[1]

elem2 = tdata[1]
xfield2 = {"value": dt.strptime(str(elem2[0]), "%Y%m%d").isoformat()}
if len(elem2) == 2:
xfield2["label"] = elem2[1]

if xfield1["value"] < xfield2["value"]:
tresult["time"].append(xfield1)
tresult["time"].append(xfield2)
else:
tresult["time"].append(xfield2)
tresult["time"].append(xfield1)
start, stop = (
TimedataValueLabel.from_list(tdata[0]),
TimedataValueLabel.from_list(tdata[1]),
)
if start.value > stop.value:
start, stop = stop, start

self.time0 = tresult["time"][0]["value"]
self.time1 = tresult["time"][1]["value"]
self.time0, self.time1 = start.value, stop.value

logger.info("Timedata: time0 is %s while time1 is %s", self.time0, self.time1)
return tresult
return TimedataOldFormat([start, stop])

def _derive_timedata_newformat(self) -> dict[str, Any]:
return TimedataOldFormat([])

def _derive_timedata_newformat(self) -> TimedataNewFormat:
"""Format input timedata to metadata, new format.
When using two dates, input convention is [[newestdate, "monitor"], [oldestdate,
Expand All @@ -700,36 +710,24 @@ def _derive_timedata_newformat(self) -> dict[str, Any]:
tresult: dict[str, Any] = {}

if len(tdata) == 1:
elem = tdata[0]
tresult["t0"] = {}
xfield = {"value": dt.strptime(str(elem[0]), "%Y%m%d").isoformat()}
self.time0 = str(elem[0])
if len(elem) == 2:
xfield["label"] = elem[1]
tresult["t0"] = xfield
start = TimedataValueLabel.from_list(tdata[0])
self.time0 = start.value
return TimedataNewFormat(t0=start, t1=None)

if len(tdata) == 2:
elem1 = tdata[0]
xfield1 = {"value": dt.strptime(str(elem1[0]), "%Y%m%d").isoformat()}
if len(elem1) == 2:
xfield1["label"] = elem1[1]

elem2 = tdata[1]
xfield2 = {"value": dt.strptime(str(elem2[0]), "%Y%m%d").isoformat()}
if len(elem2) == 2:
xfield2["label"] = elem2[1]

if xfield1["value"] < xfield2["value"]:
tresult["t0"] = xfield1
tresult["t1"] = xfield2
else:
tresult["t0"] = xfield2
tresult["t1"] = xfield1
start, stop = (
TimedataValueLabel.from_list(tdata[0]),
TimedataValueLabel.from_list(tdata[1]),
)

if start.value > stop.value:
start, stop = stop, start

self.time0 = tresult["t0"]["value"]
self.time1 = tresult["t1"]["value"]
self.time0 = start.value
self.time1 = stop.value
return TimedataNewFormat(start, stop)

logger.info("Timedata: time0 is %s while time1 is %s", self.time0, self.time1)
return tresult
return TimedataNewFormat(None, None)

def _derive_from_existing(self) -> None:
"""Derive from existing metadata."""
Expand Down Expand Up @@ -821,7 +819,7 @@ def derive_metadata(self) -> None:
meta["undef_is_zero"] = self.dataio.undef_is_zero

# timedata:
tresult = self._derive_timedata()
tresult = asdict(self._derive_timedata())
if tresult:
if self.dataio.legacy_time_format:
for key, val in tresult.items():
Expand Down

0 comments on commit eaa8b72

Please sign in to comment.