diff --git a/.dmypy.json b/.dmypy.json new file mode 100644 index 000000000..7e1c1a65b --- /dev/null +++ b/.dmypy.json @@ -0,0 +1 @@ +{"pid": 29548, "connection_name": "/var/folders/yj/3l_bv6h5763871t8jl_m7xz40000gp/T/tmpyyp2m3jv/dmypy.sock"} diff --git a/examples/s/d/nn/_project/aggregate_surfaces.py b/examples/s/d/nn/_project/aggregate_surfaces.py index a0cac8819..8bbdc27b5 100644 --- a/examples/s/d/nn/_project/aggregate_surfaces.py +++ b/examples/s/d/nn/_project/aggregate_surfaces.py @@ -67,7 +67,6 @@ def main(): ) for operation in operations: - print(f"Running aggregation: {operation}") # Call the aggregation machine and create an aggregated surface diff --git a/examples/s/d/nn/xcase/realization-0/iter-0/rms/bin/export_faultpolygons.py b/examples/s/d/nn/xcase/realization-0/iter-0/rms/bin/export_faultpolygons.py index 89da52c8f..8fc91266e 100644 --- a/examples/s/d/nn/xcase/realization-0/iter-0/rms/bin/export_faultpolygons.py +++ b/examples/s/d/nn/xcase/realization-0/iter-0/rms/bin/export_faultpolygons.py @@ -34,7 +34,6 @@ def export_faultlines(): ) for hname in HORISONNAMES: - # RMS version for reading polygons from a project: # poly = xtgeo.polygons_from_roxar(project, hname, RMS_POL_CATEGORY) diff --git a/mypy.ini b/mypy.ini index 467a0d254..2f657da80 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,4 +1,12 @@ [mypy] +disallow_untyped_defs = True +extra_checks = True +ignore_missing_imports = True +strict_equality = True +warn_redundant_casts = True +warn_unused_configs = True +warn_unused_ignores = True +exclude = ^((tests|docs|examples|bin)/|conftest.py?) [mypy-numpy.*] # Applies to Python 3.6: diff --git a/pyproject.toml b/pyproject.toml index 24f7c566a..d3995fcb0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,12 +30,12 @@ classifiers = [ ] dynamic = ["version"] dependencies = [ - "xtgeo>=2.16", - "PyYAML", - "pyarrow", "fmu-config>=1.1.0", - "pandas", "numpy", + "pandas", + "pyarrow", + "PyYAML", + "xtgeo>=2.16", ] [project.urls] @@ -49,27 +49,28 @@ dev = [ "black", "coverage>=4.1", "flake8", + "hypothesis", "isort", "jsonschema", - "hypothesis", "mypy", + "pydocstyle", "pylint", - "pytest", "pytest-cov", - "pydocstyle", - "pytest-runner", "pytest-mock", - "termcolor", + "pytest-runner", + "pytest", "rstcheck", + "termcolor", + "types-PyYAML", ] docs = [ - "pydocstyle", - "Sphinx<7", "autoapi", - "sphinx-rtd-theme", + "pydocstyle", "sphinx-autodoc-typehints<1.23", - "sphinxcontrib-apidoc", + "sphinx-rtd-theme", "sphinx-togglebutton", + "Sphinx<7", + "sphinxcontrib-apidoc", "urllib3<1.27", ] diff --git a/src/fmu/dataio/_definitions.py b/src/fmu/dataio/_definitions.py index 9ff6785e0..a3190251f 100644 --- a/src/fmu/dataio/_definitions.py +++ b/src/fmu/dataio/_definitions.py @@ -18,7 +18,7 @@ class _ValidFormats: points: dict = field(default_factory=dict) dictionary: dict = field(default_factory=dict) - def __post_init__(self): + def __post_init__(self) -> None: self.surface = {"irap_binary": ".gri"} self.grid = {"hdf": ".hdf", "roff": ".roff"} self.cube = {"segy": ".segy"} diff --git a/src/fmu/dataio/_design_kw.py b/src/fmu/dataio/_design_kw.py index d3922b26d..3a0f5b4e3 100644 --- a/src/fmu/dataio/_design_kw.py +++ b/src/fmu/dataio/_design_kw.py @@ -9,6 +9,7 @@ import logging import shlex import re +from typing import Iterable _STATUS_FILE_NAME = "DESIGN_KW.OK" @@ -17,11 +18,11 @@ def run( - template_file_name, - result_file_name, - log_level, - parameters_file_name="parameters.txt", -): + template_file_name: str, + result_file_name: str, + log_level: logging._Level, + 
parameters_file_name: str = "parameters.txt", +) -> None: # Get all key, value pairs # If FWL key is having multiple entries in the parameters file # KeyError is raised. This will be logged, and no OK @@ -29,8 +30,6 @@ def run( _logger.setLevel(log_level) - valid = True - with open(parameters_file_name) as parameters_file: parameters = parameters_file.readlines() @@ -41,24 +40,22 @@ def run( with open(template_file_name, "r") as template_file: template = template_file.readlines() - if valid: - with open(result_file_name, "w") as result_file: - for line in template: - if not is_comment(line): - for key, value in key_vals.items(): - line = line.replace(f"<{key}>", str(value)) + with open(result_file_name, "w") as result_file: + for line in template: + if not is_comment(line): + for key, value in key_vals.items(): + line = line.replace(f"<{key}>", str(value)) - if not all_matched(line, template_file_name, template): - valid = False + if not all_matched(line, template_file_name, template): + valid = False - result_file.write(line) + result_file.write(line) - if valid: - with open(_STATUS_FILE_NAME, "w") as status_file: - status_file.write("DESIGN_KW OK\n") + with open(_STATUS_FILE_NAME, "w") as status_file: + status_file.write("DESIGN_KW OK\n") -def all_matched(line, template_file_name, template): +def all_matched(line: str, template_file_name: str, template: list[str]) -> bool: valid = True for unmatched in unmatched_templates(line): if is_perl(template_file_name, template): @@ -76,25 +73,24 @@ def all_matched(line, template_file_name, template): return valid -def is_perl(file_name, template): - return file_name.endswith(".pl") or template[0].find("perl") != -1 +def is_perl(file_name: str, template: list[str]) -> bool: + return bool(file_name.endswith(".pl") or template[0].find("perl") != -1) -def unmatched_templates(line): +def unmatched_templates(line: str) -> list[str]: bracketpattern = re.compile("<.+?>") if bracketpattern.search(line): return bracketpattern.findall(line) - else: - return [] + return [] -def is_comment(line): +def is_comment(line: str) -> bool: ecl_comment_pattern = re.compile("^--") std_comment_pattern = re.compile("^#") - return ecl_comment_pattern.search(line) or std_comment_pattern.search(line) + return bool(ecl_comment_pattern.search(line) or std_comment_pattern.search(line)) -def extract_key_value(parameters): +def extract_key_value(parameters: Iterable[str]) -> dict[str, str]: """Parses a list of strings, looking for key-value pairs pr. line separated by whitespace, into a dictionary. @@ -132,7 +128,10 @@ def extract_key_value(parameters): return res -def rm_genkw_prefix(paramsdict, ignoreprefixes="LOG10_"): +def rm_genkw_prefix( + paramsdict: dict[str, object], + ignoreprefixes: str | list[str] | None = "LOG10_", +) -> dict[str, object]: """Strip prefixes from keys in a dictionary. Prefix is any string before a colon. No colon means no prefix. 
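Side note on the _design_kw.py changes above: the helpers keep their behaviour and only gain annotations, e.g. extract_key_value still maps whitespace-separated "key value" lines to a plain dict of strings. A small illustrative sketch follows (the example values are made up, not taken from this diff):

    from fmu.dataio._design_kw import extract_key_value

    # one "<key> <value>" pair per line; values remain strings
    params = extract_key_value(["FWL 1700", "MULTZ_ILE 0.005"])
    # expected, roughly: {"FWL": "1700", "MULTZ_ILE": "0.005"}
    # duplicated keys are rejected, per the comment in run() above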
@@ -156,7 +155,8 @@ def rm_genkw_prefix(paramsdict, ignoreprefixes="LOG10_"): ignoreprefixes = [] if isinstance(ignoreprefixes, str): ignoreprefixes = [ignoreprefixes] - ignoreprefixes = filter(None, ignoreprefixes) + + ignoreprefixes = list(filter(None, ignoreprefixes)) for ignore_str in ignoreprefixes: paramsdict = { diff --git a/src/fmu/dataio/_filedata_provider.py b/src/fmu/dataio/_filedata_provider.py index 803ea8055..db2fa78b3 100644 --- a/src/fmu/dataio/_filedata_provider.py +++ b/src/fmu/dataio/_filedata_provider.py @@ -41,7 +41,7 @@ class _FileDataProvider: absolute_path_symlink: Optional[str] = field(default="", init=False) checksum_md5: Optional[str] = field(default="", init=False) - def __post_init__(self): + def __post_init__(self) -> None: logger.setLevel(level=self.verbosity) if self.dataio.name: @@ -63,10 +63,11 @@ def __post_init__(self): self.fmu_context = self.dataio._usecontext # may be None! - logger.info("Initialize %s", __class__) + logger.info("Initialize %s", self.__class__) - def derive_filedata(self): + def derive_filedata(self) -> None: relpath, symrelpath = self._get_path() + assert relpath is not None relative, absolute = self._derive_filedata_generic(relpath) self.relative_path = relative self.absolute_path = absolute @@ -78,7 +79,7 @@ def derive_filedata(self): logger.info("Derived filedata") - def _derive_filedata_generic(self, inrelpath): + def _derive_filedata_generic(self, inrelpath: Path) -> tuple[str, str]: """This works with both normal data and symlinks.""" stem = self._get_filestem() @@ -116,7 +117,7 @@ def _derive_filedata_generic(self, inrelpath): logger.info("Derived filedata") return str(relpath), str(abspath) - def _get_filestem(self): + def _get_filestem(self) -> str: """Construct the file""" if not self.name: @@ -153,15 +154,15 @@ def _get_filestem(self): stem = stem.replace("__", "_") # treat norwegian special letters + # BUG(?): What about German letters like "Ü"?
stem = stem.replace("æ", "ae") stem = stem.replace("ø", "oe") stem = stem.replace("å", "aa") return stem - def _get_path(self): + def _get_path(self) -> tuple[Path, Path | None]: """Construct and get the folder path(s).""" - dest = None linkdest = None dest = self._get_path_generic(mode=self.fmu_context, allow_forcefolder=True) @@ -173,7 +174,9 @@ def _get_path(self): return dest, linkdest - def _get_path_generic(self, mode="realization", allow_forcefolder=True, info=""): + def _get_path_generic( + self, mode: str = "realization", allow_forcefolder: bool = True, info: str = "" + ) -> Path: """Generically construct and get the folder path and verify.""" dest = None @@ -215,8 +218,7 @@ def _get_path_generic(self, mode="realization", allow_forcefolder=True, info="") warn("Using absolute paths in forcefolder is not recommended!") # absolute if starts with "/", otherwise relative to outroot - dest = Path(self.dataio.forcefolder) - dest = dest.absolute() + dest = Path(self.dataio.forcefolder).absolute() self.forcefolder_is_absolute = True if not allow_forcefolder: diff --git a/src/fmu/dataio/_fmu_provider.py b/src/fmu/dataio/_fmu_provider.py index 94adbf57c..9ddf9cc09 100644 --- a/src/fmu/dataio/_fmu_provider.py +++ b/src/fmu/dataio/_fmu_provider.py @@ -65,16 +65,16 @@ class _FmuProvider: metadata: dict = field(default_factory=dict, init=False) rootpath: Optional[Path] = field(default=None, init=False) - def __post_init__(self): + def __post_init__(self) -> None: logger.setLevel(level=self.verbosity) self.rootpath = Path(self.dataio._rootpath.absolute()) self.rootpath_initial = self.rootpath - logger.info("Initialize %s", __class__) + logger.info("Initialize %s", self.__class__) - def detect_provider(self): + def detect_provider(self) -> None: """First order method to detect provider, ans also check fmu_context.""" if self._detect_ert2provider() or self._detect_ert2provider_case_only(): self.provider = "ERT2" @@ -188,7 +188,7 @@ def _detect_ert2provider_case_only(self) -> bool: return True return False - def get_ert2_information(self): + def get_ert2_information(self) -> None: """Retrieve information from an ERT2 run.""" if not self.iter_path: return @@ -197,6 +197,8 @@ def get_ert2_information(self): parameters_file = self.iter_path / "parameters.txt" if parameters_file.is_file(): params = _utils.read_parameters_txt(parameters_file) + # BUG(?): value can contain Nones, loop in fn. below + # does contains check, will fail. nested_params = _utils.nested_parameters_dict(params) self.ert2["params"] = nested_params logger.debug("parameters.txt parsed.") @@ -245,13 +247,14 @@ def get_ert2_information(self): logger.debug("ERT files has been parsed.") - def get_ert2_case_metadata(self): + def get_ert2_case_metadata(self) -> None: """Check if metadatafile file for CASE exists, and if so parse metadata. If file does not exist, still give a proposed file path, but the self.case_metadata will be {} (empty) and the physical file will not be made. 
""" + assert self.rootpath is not None self.case_metafile = self.rootpath / ERT2_RELATIVE_CASE_METADATA_FILE self.case_metafile = self.case_metafile.resolve() if self.case_metafile.exists(): @@ -263,7 +266,7 @@ def get_ert2_case_metadata(self): "Case metadata file does not exists as %s", str(self.case_metafile) ) - def generate_ert2_metadata(self): + def generate_ert2_metadata(self) -> None: """Construct the metadata FMU block for an ERT2 forward job.""" logger.info("Generate ERT2 metadata...") diff --git a/src/fmu/dataio/_metadata.py b/src/fmu/dataio/_metadata.py index 801625ed5..9bfe2e2fb 100644 --- a/src/fmu/dataio/_metadata.py +++ b/src/fmu/dataio/_metadata.py @@ -42,7 +42,7 @@ def default_meta_dollars() -> dict: return dollars -def generate_meta_tracklog() -> list: +def generate_meta_tracklog() -> list[dict]: """Create the tracklog metadata, which here assumes 'created' only.""" meta = list() @@ -218,7 +218,7 @@ class _MetaData: # if re-using existing metadata meta_existing: dict = field(default_factory=dict, init=False) - def __post_init__(self): + def __post_init__(self) -> None: logger.setLevel(level=self.verbosity) logger.info("Initialize _MetaData instance.") @@ -229,7 +229,7 @@ def __post_init__(self): logger.info("Partially reuse existing metadata from %s", self.obj) self.meta_existing = read_metadata(self.obj) - def _populate_meta_objectdata(self): + def _populate_meta_objectdata(self) -> None: """Analyze the actual object together with input settings. This will provide input to the ``data`` block of the metas but has also @@ -254,7 +254,7 @@ def _get_case_metadata(self): logger.info("FMU provider is %s", self.fmudata.provider) return self.fmudata.case_metadata - def _populate_meta_fmu(self): + def _populate_meta_fmu(self) -> None: """Populate the fmu block in the metadata. This block may be missing in case the client is not within a FMU run, e.g. @@ -268,7 +268,7 @@ def _populate_meta_fmu(self): self.meta_fmu = self.fmudata.metadata self.rootpath = self.fmudata.rootpath - def _populate_meta_file(self): + def _populate_meta_file(self) -> None: """Populate the file block in the metadata. The file block also contains all needed info for doing the actual file export. @@ -313,19 +313,19 @@ def _populate_meta_file(self): logger.info("Do not compute MD5 sum at this stage!") self.meta_file["checksum_md5"] = None - def _populate_meta_class(self): + def _populate_meta_class(self) -> None: """Get the general class which is a simple string.""" self.meta_class = self.objdata.classname - def _populate_meta_tracklog(self): + def _populate_meta_tracklog(self) -> None: """Create the tracklog metadata, which here assumes 'created' only.""" self.meta_tracklog = generate_meta_tracklog() - def _populate_meta_masterdata(self): + def _populate_meta_masterdata(self) -> None: """Populate metadata from masterdata section in config.""" self.meta_masterdata = generate_meta_masterdata(self.dataio.config) - def _populate_meta_access(self): + def _populate_meta_access(self) -> None: """Populate metadata overall from access section in config + allowed keys. Access should be possible to change per object, based on user input. 
@@ -341,7 +341,7 @@ def _populate_meta_access(self): if self.dataio: self.meta_access = generate_meta_access(self.dataio.config) - def _populate_meta_display(self): + def _populate_meta_display(self) -> None: """Populate the display block.""" # display.name @@ -352,14 +352,14 @@ def _populate_meta_display(self): self.meta_display = {"name": display_name} - def _populate_meta_xpreprocessed(self): + def _populate_meta_xpreprocessed(self) -> None: """Populate a few necessary 'tmp' metadata needed for preprocessed data.""" if self.dataio.fmu_context == "preprocessed": self.meta_xpreprocessed["name"] = self.dataio.name self.meta_xpreprocessed["tagname"] = self.dataio.tagname self.meta_xpreprocessed["subfolder"] = self.dataio.subfolder - def _reuse_existing_metadata(self, meta): + def _reuse_existing_metadata(self, meta: dict) -> dict: """Perform a merge procedure if the key `reuse_metadata_rule` is active.""" if self.dataio and self.dataio.reuse_metadata_rule: oldmeta = self.meta_existing @@ -373,7 +373,9 @@ def _reuse_existing_metadata(self, meta): ) return meta - def generate_export_metadata(self, skip_null=True) -> dict: # TODO! -> skip_null? + def generate_export_metadata( + self, skip_null: bool = True + ) -> dict: # TODO! -> skip_null? """Main function to generate the full metadata""" # populate order matters, in particular objectdata provides input to class/file diff --git a/src/fmu/dataio/_objectdata_provider.py b/src/fmu/dataio/_objectdata_provider.py index 36ea6e51f..efa676dfb 100644 --- a/src/fmu/dataio/_objectdata_provider.py +++ b/src/fmu/dataio/_objectdata_provider.py @@ -83,11 +83,13 @@ - Depth surfaces extracted from the structural model """ +from __future__ import annotations + import logging from dataclasses import dataclass, field from datetime import datetime as dt from pathlib import Path -from typing import Any, Optional, Tuple +from typing import Any, Optional, Tuple, cast from warnings import warn import numpy as np @@ -143,7 +145,7 @@ class _ObjectDataProvider: time0: str = field(default="", init=False) time1: str = field(default="", init=False) - def __post_init__(self): + def __post_init__(self) -> None: logger.info("Ran __post_init__") def _derive_name_stratigraphy(self) -> dict: @@ -157,7 +159,7 @@ def _derive_name_stratigraphy(self) -> dict: """ logger.info("Evaluate data:name attribute and stratigraphy") - result = dict() # shorter form + result = dict[str, object]() name = self.dataio.name @@ -190,7 +192,9 @@ def _derive_name_stratigraphy(self) -> dict: return result @staticmethod - def _validate_get_ext(fmt, subtype, validator): + def _validate_get_ext( + fmt: str, subtype: str, validator: dict[str, object] + ) -> object | None: """Validate that fmt (file format) matches data and return legal extension.""" if fmt not in validator.keys(): raise ConfigurationError( @@ -198,13 +202,12 @@ def _validate_get_ext(fmt, subtype, validator): f"Valid {subtype} formats are: {list(validator.keys())}", ) - ext = validator.get(fmt, None) - return ext + return validator.get(fmt, None) - def _derive_objectdata(self): + def _derive_objectdata(self) -> dict: """Derive object spesific data.""" logger.info("Evaluate data settings for object") - result = dict() + result = dict[str, object]() if isinstance(self.obj, xtgeo.RegularSurface): result["subtype"] = "RegularSurface" @@ -213,7 +216,9 @@ def _derive_objectdata(self): result["efolder"] = "maps" result["fmt"] = self.dataio.surface_fformat result["extension"] = self._validate_get_ext( - result["fmt"], result["subtype"], 
_ValidFormats().surface + cast(str, result["fmt"]), + cast(str, result["subtype"]), + _ValidFormats().surface, ) result["spec"], result["bbox"] = self._derive_spec_bbox_regularsurface() @@ -224,7 +229,9 @@ def _derive_objectdata(self): result["efolder"] = "polygons" result["fmt"] = self.dataio.polygons_fformat result["extension"] = self._validate_get_ext( - result["fmt"], result["subtype"], _ValidFormats().polygons + cast(str, result["fmt"]), + cast(str, result["subtype"]), + _ValidFormats().polygons, ) result["spec"], result["bbox"] = self._derive_spec_bbox_polygons() @@ -235,7 +242,9 @@ def _derive_objectdata(self): result["efolder"] = "points" result["fmt"] = self.dataio.points_fformat result["extension"] = self._validate_get_ext( - result["fmt"], result["subtype"], _ValidFormats().points + cast(str, result["fmt"]), + cast(str, result["subtype"]), + _ValidFormats().points, ) result["spec"], result["bbox"] = self._derive_spec_bbox_points() @@ -246,7 +255,9 @@ def _derive_objectdata(self): result["efolder"] = "cubes" result["fmt"] = self.dataio.cube_fformat result["extension"] = self._validate_get_ext( - result["fmt"], result["subtype"], _ValidFormats().cube + cast(str, result["fmt"]), + cast(str, result["subtype"]), + _ValidFormats().cube, ) result["spec"], result["bbox"] = self._derive_spec_bbox_cube() @@ -257,7 +268,9 @@ def _derive_objectdata(self): result["efolder"] = "grids" result["fmt"] = self.dataio.grid_fformat result["extension"] = self._validate_get_ext( - result["fmt"], result["subtype"], _ValidFormats().grid + cast(str, result["fmt"]), + cast(str, result["subtype"]), + _ValidFormats().grid, ) result["spec"], result["bbox"] = self._derive_spec_bbox_cpgrid() @@ -268,7 +281,9 @@ def _derive_objectdata(self): result["efolder"] = "grids" result["fmt"] = self.dataio.grid_fformat result["extension"] = self._validate_get_ext( - result["fmt"], result["subtype"], _ValidFormats().grid + cast(str, result["fmt"]), + cast(str, result["subtype"]), + _ValidFormats().grid, ) result["spec"], result["bbox"] = self._derive_spec_bbox_cpgridproperty() @@ -281,7 +296,9 @@ def _derive_objectdata(self): result["efolder"] = "tables" result["fmt"] = self.dataio.table_fformat result["extension"] = self._validate_get_ext( - result["fmt"], result["subtype"], _ValidFormats().table + cast(str, result["fmt"]), + cast(str, result["subtype"]), + _ValidFormats().table, ) result["spec"], result["bbox"] = self._derive_spec_bbox_dataframe() @@ -294,7 +311,9 @@ def _derive_objectdata(self): result["efolder"] = "tables" result["fmt"] = self.dataio.arrow_fformat result["extension"] = self._validate_get_ext( - result["fmt"], result["subtype"], _ValidFormats().table + cast(str, result["fmt"]), + cast(str, result["subtype"]), + _ValidFormats().table, ) result["spec"], result["bbox"] = self._derive_spec_bbox_arrowtable() @@ -305,7 +324,9 @@ def _derive_objectdata(self): result["efolder"] = "dictionaries" result["fmt"] = self.dataio.dict_fformat result["extension"] = self._validate_get_ext( - result["fmt"], result["subtype"], _ValidFormats().dictionary + cast(str, result["fmt"]), + cast(str, result["subtype"]), + _ValidFormats().dictionary, ) result["spec"], result["bbox"] = self._derive_spec_bbox_dict() @@ -327,7 +348,7 @@ def _derive_objectdata(self): return result - def _derive_spec_bbox_regularsurface(self): + def _derive_spec_bbox_regularsurface(self) -> tuple[dict, dict]: """Process/collect the data.spec and data.bbox for RegularSurface""" logger.info("Derive bbox and specs for RegularSurface") regsurf = self.obj 
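A note on the pattern running through _derive_objectdata and the spec/bbox helpers above: result is created as dict[str, object]() rather than dict(), so mypy knows the value type, and reads such as result["fmt"] must then be narrowed with cast(str, ...) before being used as a string. A minimal stand-alone illustration (hypothetical snippet, not repository code; the subscripted builtin needs Python 3.9+):

    from typing import cast

    untyped = dict()               # mypy infers dict[Any, Any]
    typed = dict[str, object]()    # mypy infers dict[str, object]
    typed["fmt"] = "irap_binary"
    fmt = cast(str, typed["fmt"])  # narrow object -> str before string-specific use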
@@ -351,7 +372,7 @@ def _derive_spec_bbox_regularsurface(self): return specs, bbox - def _derive_spec_bbox_polygons(self): + def _derive_spec_bbox_polygons(self) -> tuple[dict, dict]: """Process/collect the data.spec and data.bbox for Polygons""" logger.info("Derive bbox and specs for Polygons") poly = self.obj @@ -370,13 +391,13 @@ def _derive_spec_bbox_polygons(self): bbox["zmax"] = float(zmax) return specs, bbox - def _derive_spec_bbox_points(self): + def _derive_spec_bbox_points(self) -> tuple[dict[str, object], dict[str, object]]: """Process/collect the data.spec and data.bbox for Points""" logger.info("Derive bbox and specs for Points") pnts = self.obj - specs = dict() - bbox = dict() + specs = dict[str, object]() + bbox = dict[str, object]() if len(pnts.dataframe.columns) > 3: attrnames = pnts.dataframe.columns[3:] @@ -392,7 +413,7 @@ def _derive_spec_bbox_points(self): return specs, bbox - def _derive_spec_bbox_cube(self): + def _derive_spec_bbox_cube(self) -> tuple[dict, dict]: """Process/collect the data.spec and data.bbox Cube""" logger.info("Derive bbox and specs for Cube") cube = self.obj @@ -429,7 +450,7 @@ def _derive_spec_bbox_cube(self): return specs, bbox - def _derive_spec_bbox_cpgrid(self): + def _derive_spec_bbox_cpgrid(self) -> tuple[dict, dict]: """Process/collect the data.spec and data.bbox CornerPoint Grid geometry""" logger.info("Derive bbox and specs for Gride (geometry)") grid = self.obj @@ -453,55 +474,55 @@ def _derive_spec_bbox_cpgrid(self): bbox["zmax"] = round(float(geox["zmax"]), 4) return specs, bbox - def _derive_spec_bbox_cpgridproperty(self): + def _derive_spec_bbox_cpgridproperty(self) -> tuple[dict, dict]: """Process/collect the data.spec and data.bbox GridProperty""" logger.info("Derive bbox and specs for GridProperty") gridprop = self.obj - specs = dict() - bbox = dict() + specs = dict[str, object]() + bbox = dict[str, object]() specs["ncol"] = gridprop.ncol specs["nrow"] = gridprop.nrow specs["nlay"] = gridprop.nlay return specs, bbox - def _derive_spec_bbox_dataframe(self): + def _derive_spec_bbox_dataframe( + self, + ) -> tuple[dict[str, object], dict[str, object],]: """Process/collect the data items for DataFrame.""" logger.info("Process data metadata for DataFrame (tables)") dfr = self.obj - specs = dict() - bbox = dict() + specs = dict[str, object]() + bbox = dict[str, object]() specs["columns"] = list(dfr.columns) specs["size"] = int(dfr.size) return specs, bbox - def _derive_spec_bbox_arrowtable(self): + def _derive_spec_bbox_arrowtable( + self, + ) -> tuple[dict[str, object], dict[str, object],]: """Process/collect the data items for Arrow table.""" logger.info("Process data metadata for arrow (tables)") table = self.obj - specs = dict() - bbox = dict() + specs = dict[str, object]() + bbox = dict[str, object]() specs["columns"] = list(table.column_names) specs["size"] = table.num_columns * table.num_rows return specs, bbox - def _derive_spec_bbox_dict(self): + def _derive_spec_bbox_dict(self) -> tuple[dict[str, object], dict[str, object]]: """Process/collect the data items for dictionary.""" logger.info("Process data metadata for dictionary") + return {}, {} - specs = dict() - bbox = dict() - - return specs, bbox - - def _get_columns(self): + def _get_columns(self) -> list[str]: """Get the columns from table""" if isinstance(self.obj, pd.DataFrame): logger.debug("pandas") @@ -512,7 +533,7 @@ def _get_columns(self): logger.debug("Available columns in table %s ", columns) return columns - def _derive_index(self): + def _derive_index(self) -> 
list[str]: """Derive table index""" # This could in the future also return context columns = self._get_columns() @@ -535,7 +556,7 @@ def _derive_index(self): self._check_index(index) return index - def _check_index(self, index): + def _check_index(self, index: list[str]) -> None: """Check the table index. Args: index (list): list of column names @@ -548,7 +569,7 @@ def _check_index(self, index): for not_found in not_founds: raise KeyError(f"{not_found} is not in table") - def _derive_timedata(self): + def _derive_timedata(self) -> dict: """Format input timedata to metadata.""" tdata = self.dataio.timedata @@ -561,11 +582,12 @@ def _derive_timedata(self): timedata = self._derive_timedata_newformat() return timedata - def _derive_timedata_legacy(self): + def _derive_timedata_legacy(self) -> dict[str, object]: """Format input timedata to metadata. legacy version.""" + # TODO(JB): Convert tresult to TypedDict or Dataclass. tdata = self.dataio.timedata - tresult = dict() + tresult = dict[str, object]() tresult["time"] = list() if len(tdata) == 1: elem = tdata[0] @@ -574,7 +596,7 @@ def _derive_timedata_legacy(self): self.time0 = str(elem[0]) if len(elem) == 2: xfield["label"] = elem[1] - tresult["time"].append(xfield) + cast(list, tresult["time"]).append(xfield) if len(tdata) == 2: elem1 = tdata[0] xfield1 = {"value": dt.strptime(str(elem1[0]), "%Y%m%d").isoformat()} @@ -587,19 +609,19 @@ def _derive_timedata_legacy(self): xfield2["label"] = elem2[1] if xfield1["value"] < xfield2["value"]: - tresult["time"].append(xfield1) - tresult["time"].append(xfield2) + cast(list, tresult["time"]).append(xfield1) + cast(list, tresult["time"]).append(xfield2) else: - tresult["time"].append(xfield2) - tresult["time"].append(xfield1) + cast(list, tresult["time"]).append(xfield2) + cast(list, tresult["time"]).append(xfield1) - self.time0 = tresult["time"][0]["value"] - self.time1 = tresult["time"][1]["value"] + self.time0 = cast(list, tresult["time"])[0]["value"] + self.time1 = cast(list, tresult["time"])[1]["value"] logger.info("Timedata: time0 is %s while time1 is %s", self.time0, self.time1) return tresult - def _derive_timedata_newformat(self): + def _derive_timedata_newformat(self) -> dict[str, object]: """Format input timedata to metadata, new format. When using two dates, input convention is [[newestdate, "monitor"], [oldestdate, @@ -609,7 +631,7 @@ def _derive_timedata_newformat(self): set for those who wants it turned around). """ tdata = self.dataio.timedata - tresult = dict() + tresult = dict[str, object]() if len(tdata) == 1: elem = tdata[0] @@ -637,17 +659,18 @@ def _derive_timedata_newformat(self): tresult["t0"] = xfield2 tresult["t1"] = xfield1 - self.time0 = tresult["t0"]["value"] - self.time1 = tresult["t1"]["value"] + self.time0 = cast(dict, tresult["t0"])["value"] + self.time1 = cast(dict, tresult["t1"])["value"] logger.info("Timedata: time0 is %s while time1 is %s", self.time0, self.time1) return tresult - def _derive_from_existing(self): + def _derive_from_existing(self) -> None: """Derive from existing metadata.""" # do not change any items in 'data' block, as it may ruin e.g. stratigrapical # setting (i.e.
changing data.name is not allowed) + assert self.meta_existing is not None self.metadata = self.meta_existing["data"] self.name = self.meta_existing["data"]["name"] @@ -694,7 +717,7 @@ def _process_content(self) -> Tuple[str, Optional[dict]]: return content, content_spesific - def derive_metadata(self): + def derive_metadata(self) -> None: """Main function here, will populate the metadata block for 'data'.""" logger.info("Derive all metadata for data object...") diff --git a/src/fmu/dataio/_oyaml.py b/src/fmu/dataio/_oyaml.py index 7722483f0..3958818d9 100644 --- a/src/fmu/dataio/_oyaml.py +++ b/src/fmu/dataio/_oyaml.py @@ -1,4 +1,5 @@ # flake8: noqa +# type: ignore # Borrowed from OYAML 0.5 on the internet... import sys from collections import OrderedDict diff --git a/src/fmu/dataio/_utils.py b/src/fmu/dataio/_utils.py index 014f7769d..08c2a597e 100644 --- a/src/fmu/dataio/_utils.py +++ b/src/fmu/dataio/_utils.py @@ -1,4 +1,6 @@ """Module for private utilities/helpers for DataIO class.""" +from __future__ import annotations +import contextlib import hashlib import json import logging @@ -10,7 +12,7 @@ from copy import deepcopy from datetime import datetime from pathlib import Path -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Union, cast import pandas as pd # type: ignore import yaml @@ -32,17 +34,6 @@ logger = logging.getLogger(__name__) -def inherit_docstring(inherit_from): - """Local decorator to inherit a docstring""" - - def decorator_set_docstring(func): - if func.__doc__ is None and inherit_from.__doc__ is not None: - func.__doc__ = inherit_from.__doc__ - return func - - return decorator_set_docstring - - def detect_inside_rms() -> bool: """Detect if 'truly' inside RMS GUI, where predefined variable project exist. @@ -90,7 +81,9 @@ def drop_nones(dinput: dict) -> dict: return dd -def export_metadata_file(yfile, metadata, savefmt="yaml", verbosity="WARNING") -> None: +def export_metadata_file( + yfile: str, metadata: dict, savefmt: str = "yaml", verbosity: str = "WARNING" +) -> None: """Export genericly and ordered to the complementary metadata file.""" logger.setLevel(level=verbosity) if metadata: @@ -113,7 +106,9 @@ def export_metadata_file(yfile, metadata, savefmt="yaml", verbosity="WARNING") - logger.info("Yaml file on: %s", yfile) -def export_file(obj, filename, extension, flag=None): +def export_file( + obj: object, filename: str, extension: str, flag: str | None = None +) -> str: """Export a valid object to file""" if isinstance(obj, Path): @@ -123,6 +118,7 @@ def export_file(obj, filename, extension, flag=None): obj.to_file(filename, fformat="irap_binary") elif extension == ".csv" and isinstance(obj, (xtgeo.Polygons, xtgeo.Points)): out = obj.copy() # to not modify incoming instance! 
+ assert flag is not None if "xtgeo" not in flag: out.xname = "X" out.yname = "Y" @@ -158,7 +154,7 @@ def export_file(obj, filename, extension, flag=None): return str(filename) -def md5sum(fname): +def md5sum(fname: str) -> str: hash_md5 = hashlib.md5() with open(fname, "rb") as fil: for chunk in iter(lambda: fil.read(4096), b""): @@ -166,14 +162,21 @@ def md5sum(fname): return hash_md5.hexdigest() -def export_file_compute_checksum_md5(obj, filename, extension, flag=None, tmp=False): +def export_file_compute_checksum_md5( + obj: object, + filename: str, + extension: str, + flag: str | None = None, + tmp: bool = False, +) -> tuple[str | Path | None, str,]: """Export and compute checksum, with possibility to use a tmp file.""" - usefile = filename + usefile: str | None = filename if tmp: tmpdir = tempfile.TemporaryDirectory() - usefile = Path(tmpdir.name) / "tmpfile" + usefile = str(Path(tmpdir.name) / "tmpfile") + assert usefile is not None export_file(obj, usefile, extension, flag=flag) checksum = md5sum(usefile) if tmp: @@ -182,7 +185,7 @@ def export_file_compute_checksum_md5(obj, filename, extension, flag=None, tmp=Fa return usefile, checksum -def create_symlink(source, target): +def create_symlink(source: str, target: str) -> None: """Create a symlinked file with some checks.""" thesource = Path(source) @@ -200,16 +203,18 @@ def create_symlink(source, target): raise IOError(f"Target file {thesource} does not exist or is not a symlink.") -def size(fname): +def size(fname: str) -> int: return Path(fname).stat().st_size -def uuid_from_string(string): +def uuid_from_string(string: str) -> str: """Produce valid and repeteable UUID4 as a hash of given string""" return str(uuid.UUID(hashlib.md5(string.encode("utf-8")).hexdigest())) -def read_parameters_txt(pfile: Union[Path, str]) -> Dict[str, Union[str, float, int]]: +def read_parameters_txt( + pfile: Union[Path, str] +) -> Dict[str, Union[str, float, int | None]]: """Read the parameters.txt file and convert to a dict. The parameters.txt file has this structure:: SENSNAME rms_seed @@ -283,28 +288,22 @@ def nested_parameters_dict( return nested_dict -def check_if_number(value): +def check_if_number(value: str | None) -> int | float | str | None: """Check if value (str) looks like a number and return the converted value.""" if value is None: - return + return None - res = None - try: - res = int(value) - except ValueError: - try: - res = float(value) - except ValueError: - pass + with contextlib.suppress(ValueError): + return int(value) - if res is not None: - return res + with contextlib.suppress(ValueError): + return float(value) return value -def get_object_name(obj): +def get_object_name(obj: Path) -> str | None: """Get the name of the object. If not possible, return None. 
@@ -320,19 +319,19 @@ def get_object_name(obj): name = obj.name except AttributeError: logger.info("display.name could not be set") - return + return None if isinstance(obj, xtgeo.RegularSurface) and name == "unknown": logger.debug("Got 'unknown' as name from a surface object, returning None") - return + return None if isinstance(obj, xtgeo.Polygons) and name == "poly": logger.debug("Got 'poly' as name from a polygons object, returning None") - return + return None if isinstance(obj, xtgeo.Grid) and name == "noname": logger.debug("Got 'noname' as name from grids object, returning None") - return + return None return name @@ -342,7 +341,7 @@ def prettyprint_dict(inp: dict) -> str: return str(json.dumps(inp, indent=2, default=str, ensure_ascii=False)) -def some_config_from_env(envvar="FMU_GLOBAL_CONFIG") -> dict: +def some_config_from_env(envvar: str = "FMU_GLOBAL_CONFIG") -> dict | None: """Get the config from environment variable. This function is only called if config SHALL be fetched from the environment @@ -354,14 +353,12 @@ def some_config_from_env(envvar="FMU_GLOBAL_CONFIG") -> dict: cfg_path = os.environ[envvar] else: warnings.warn( - ( - "No config was received. " - "The config should be given explicitly as an input argument, or " - f"the environment variable {envvar} must point to a valid yaml file. " - "A missing config will still export a file, but without a metadata " - "file. Such exports may be disabled in a future version of fmu.dataio", - UserWarning, - ) + "No config was received. " + "The config should be given explicitly as an input argument, or " + f"the environment variable {envvar} must point to a valid yaml file. " + "A missing config will still export a file, but without a metadata " + "file. Such exports may be disabled in a future version of fmu.dataio", + UserWarning, ) return None @@ -433,7 +430,9 @@ def read_metadata(filename: Union[str, Path]) -> dict: return metacfg -def glue_metadata_preprocessed(oldmeta, newmeta): +def glue_metadata_preprocessed( + oldmeta: dict[str, object], newmeta: dict[str, object] +) -> dict[str, object]: """Glue (combine) to metadata dicts according to rule 'preprocessed'.""" meta = oldmeta.copy() @@ -444,13 +443,15 @@ def glue_metadata_preprocessed(oldmeta, newmeta): meta["fmu"] = newmeta["fmu"] meta["file"] = newmeta["file"] - newmeta["tracklog"][-1]["event"] = "merged" - meta["tracklog"].extend(newmeta["tracklog"]) + cast(list, newmeta["tracklog"])[-1]["event"] = "merged" + cast(list, meta["tracklog"]).extend(cast(list, newmeta["tracklog"])) return meta -def parse_timedata(datablock: dict, isoformat=True): +def parse_timedata( + datablock: dict, isoformat: bool = True +) -> tuple[str | None, str | None]: """The time section under datablock has variants to parse. 
Formats:: @@ -494,7 +495,7 @@ def parse_timedata(datablock: dict, isoformat=True): if isinstance(datablock["time"], list): date0 = datablock["time"][0]["value"] - if len(datablock["time"] == 2): + if len(datablock["time"]) == 2: date1 = datablock["time"][1]["value"] elif isinstance(datablock["time"], dict): @@ -505,10 +506,10 @@ def parse_timedata(datablock: dict, isoformat=True): if not isoformat: if date0: tdate0 = datetime.strptime(date0, "%Y-%m-%dT%H:%M:%S") - date0 = tdate0.datetime.strftime("%Y%m%d") + date0 = tdate0.strftime("%Y%m%d") if date1: tdate1 = datetime.strptime(date1, "%Y-%m-%dT%H:%M:%S") - date1 = tdate1.datetime.strftime("%Y%m%d") + date1 = tdate1.strftime("%Y%m%d") return (date0, date1) diff --git a/src/fmu/dataio/dataio.py b/src/fmu/dataio/dataio.py index 54ef5593d..81afb6ab5 100644 --- a/src/fmu/dataio/dataio.py +++ b/src/fmu/dataio/dataio.py @@ -2,6 +2,8 @@ The metadata spec is documented as a JSON schema, stored under schema/. """ +from __future__ import annotations + import logging import os import uuid @@ -9,7 +11,7 @@ from copy import deepcopy from dataclasses import dataclass, field from pathlib import Path -from typing import Any, ClassVar, List, Optional, Tuple, Union +from typing import Any, ClassVar, List, Optional, Tuple, Union, cast from warnings import warn import pandas as pd # type: ignore @@ -53,14 +55,14 @@ class ValidationError(ValueError, KeyError): # ====================================================================================== -def _validate_variable(key, value, legals) -> bool: +def _validate_variable(key: str, value: type, legals: dict[str, str | type]) -> bool: """Use data from __annotions__ to validate that overriden var. is of legal type.""" if key not in legals: logger.warning("Unsupported key, raise an error") raise ValidationError(f"The input key '{key}' is not supported") - if isinstance(legals[key], str): - valid_type = eval(legals[key]) # pylint: disable=eval-used + if isinstance(lkey := legals[key], str): + valid_type = eval(lkey) # pylint: disable=eval-used else: valid_type = legals[key] @@ -233,7 +235,7 @@ def _check_content(proposed: Union[str, dict]) -> Any: return usecontent, content_specific -def _content_validate(name, fields): +def _content_validate(name: str, fields: dict[str, type]) -> None: logger.debug("starting staticmethod _data_process_content_validate") valid = ALLOWED_CONTENTS.get(name, None) if valid is None: @@ -599,7 +601,7 @@ class ExportData: # << NB! 
storing ACTUAL casepath: _rootpath: Path = field(default_factory=Path, init=False) - def __post_init__(self): + def __post_init__(self) -> None: logger.setLevel(level=self.verbosity) logger.info("Running __post_init__ ExportData") logger.debug("Global config is %s", prettyprint_dict(self.config)) @@ -652,7 +654,7 @@ def __post_init__(self): logger.info("FMU context is %s", self.fmu_context) logger.info("Ran __post_init__") - def _show_deprecations_or_notimplemented(self): + def _show_deprecations_or_notimplemented(self) -> None: """Warn on deprecated keys or on stuff not implemented yet.""" if self.runpath: @@ -669,12 +671,14 @@ def _show_deprecations_or_notimplemented(self): PendingDeprecationWarning, ) - def _validate_content_key(self): + def _validate_content_key(self) -> None: """Validate the given 'content' input.""" - self._usecontent, self._content_specific = _check_content(self.content) + self._usecontent, self._content_specific = _check_content( + cast(str | dict, self.content), + ) - def _validate_fmucontext_key(self): + def _validate_fmucontext_key(self) -> None: """Validate the given 'fmu_context' input.""" if self.fmu_context not in ALLOWED_FMU_CONTEXTS: msg = "" @@ -716,7 +720,7 @@ def _update_check_settings(self, newsettings: dict) -> None: self._validate_fmucontext_key() logger.info("Validate FMU context which is now %s", self.fmu_context) - def _update_globalconfig_from_settings(self): + def _update_globalconfig_from_settings(self) -> None: """A few user settings may update/append the global config directly.""" newglobals = deepcopy(self.config) @@ -732,7 +736,7 @@ def _update_globalconfig_from_settings(self): self.config = newglobals - def _establish_pwd_rootpath(self): + def _establish_pwd_rootpath(self) -> None: """Establish state variables pwd and the (initial) rootpath. The self._pwd stores the process working directory, i.e. the folder @@ -764,7 +768,7 @@ def _establish_pwd_rootpath(self): if self._inside_rms or INSIDE_RMS or "RUN_DATAIO_EXAMPLES" in os.environ: self._rootpath = (self._pwd / "../../.").absolute().resolve() logger.info("Run from inside RMS (or pretend)") - self._inside_rms = True + self._inside_rms = True # BUG(?): Should be ExportData._inside_rms? # make some extra keys in settings: self._usecontext = self.fmu_context # may change later! @@ -809,7 +813,9 @@ def _check_obj_if_file(self, obj: Any) -> Any: # Public methods: # ================================================================================== - def generate_metadata(self, obj: Any, compute_md5: bool = True, **kwargs) -> dict: + def generate_metadata( + self, obj: object, compute_md5: bool = True, **kwargs: object + ) -> dict: """Generate and return the complete metadata for a provided object. An object may be a map, 3D grid, cube, table, etc which is of a known and @@ -861,7 +867,12 @@ def generate_metadata(self, obj: Any, compute_md5: bool = True, **kwargs) -> dic return deepcopy(self._metadata) - def export(self, obj, return_symlink=False, **kwargs) -> str: + def export( + self, + obj: object, + return_symlink: bool = False, + **kwargs: dict[str, object], + ) -> str: """Export data objects of 'known' type to FMU storage solution with metadata. This function will also collect the data spesific class metadata. 
For "classic" @@ -977,7 +988,7 @@ class InitializeCase: # pylint: disable=too-few-public-methods _pwd: Path = field(default_factory=Path, init=False) _casepath: Path = field(default_factory=Path, init=False) - def __post_init__(self): + def __post_init__(self) -> None: logger.setLevel(level=self.verbosity) if not self.config or GLOBAL_ENVNAME in os.environ: @@ -1010,7 +1021,7 @@ def _update_settings(self, newsettings: dict) -> None: logger.setLevel(level=self.verbosity) logger.info("New setting OK for %s", setting) - def _establish_pwd_casepath(self): + def _establish_pwd_casepath(self) -> None: """Establish state variables pwd and casepath. See ExportData's method but this is much simpler (e.g. no RMS context) @@ -1031,7 +1042,7 @@ def _establish_pwd_casepath(self): logger.info("Set PWD (case): %s", str(self._pwd)) logger.info("Set rootpath (case): %s", str(self._casepath)) - def _check_already_metadata_or_create_folder(self, force=False) -> bool: + def _check_already_metadata_or_create_folder(self, force: bool = False) -> bool: if not self._casepath.exists(): self._casepath.mkdir(parents=True, exist_ok=True) logger.info("Created rootpath (case) %s", self._casepath) @@ -1054,7 +1065,10 @@ def _check_already_metadata_or_create_folder(self, force=False) -> bool: # ================================================================================== def generate_metadata( - self, force: bool = False, skip_null=True, **kwargs + self, + force: bool = False, + skip_null: bool = True, + **kwargs: object, ) -> Union[dict, None]: """Generate case metadata. @@ -1091,8 +1105,9 @@ def generate_metadata( # only asset, not ssdl access = _metadata.generate_meta_access(self.config) + assert access is not None meta["access"] = dict() - meta["access"]["asset"] = access["asset"] + meta["access"]["asset"] = cast(dict, access["asset"]) meta["fmu"] = dict() meta["fmu"]["model"] = self.config["model"] @@ -1118,7 +1133,12 @@ def generate_metadata( # alias generate_case_metadata = generate_metadata - def export(self, force: bool = False, skip_null=True, **kwargs) -> Union[str, None]: + def export( + self, + force: bool = False, + skip_null: bool = True, + **kwargs: dict[str, object], + ) -> Union[str, None]: """Export case metadata to file. Args: @@ -1197,7 +1217,7 @@ class AggregatedData: _metadata: dict = field(default_factory=dict, init=False) _metafile: Path = field(default_factory=Path, init=False) - def __post_init__(self): + def __post_init__(self) -> None: logger.setLevel(level=self.verbosity) @staticmethod @@ -1225,7 +1245,7 @@ def _update_settings(self, newsettings: dict) -> None: logger.setLevel(level=self.verbosity) logger.info("New setting OK for %s", setting) - def _construct_filename(self, template: dict) -> Tuple[Path, Path]: + def _construct_filename(self, template: dict) -> Tuple[Path, Path | None]: """Construct the paths/filenames for aggregated data. 
These filenames are constructed a bit different than in a forward job, since we @@ -1328,8 +1348,12 @@ def _construct_filename(self, template: dict) -> Tuple[Path, Path]: return relname, absname def _generate_aggrd_metadata( - self, obj: Any, real_ids: List[int], uuids: List[str], compute_md5: bool = True - ): + self, + obj: object, + real_ids: List[int], + uuids: List[str], + compute_md5: bool = True, + ) -> None: logger.info( "self.aggregation is %s (%s)", self.aggregation_id, @@ -1397,10 +1421,10 @@ def _generate_aggrd_metadata( def generate_metadata( self, - obj: Any, + obj: object, compute_md5: bool = True, skip_null: bool = True, - **kwargs, + **kwargs: object, ) -> dict: """Generate metadata for the aggregated data. @@ -1446,17 +1470,17 @@ def generate_metadata( # alias method def generate_aggregation_metadata( self, - obj: Any, + obj: object, compute_md5: bool = True, skip_null: bool = True, - **kwargs, + **kwargs: object, ) -> dict: """Alias method name, see ``generate_metadata``""" return self.generate_metadata( obj, compute_md5=compute_md5, skip_null=skip_null, **kwargs ) - def export(self, obj, **kwargs) -> str: + def export(self, obj: object, **kwargs: object) -> str: """Export aggregated file with metadata to file. Args: diff --git a/src/fmu/dataio/hook_implementations/jobs.py b/src/fmu/dataio/hook_implementations/jobs.py index 9c06e79fb..1bc6c6475 100644 --- a/src/fmu/dataio/hook_implementations/jobs.py +++ b/src/fmu/dataio/hook_implementations/jobs.py @@ -8,5 +8,5 @@ @hook_implementation @plugin_response(plugin_name="fmu_dataio") -def installable_workflow_jobs(): +def installable_workflow_jobs() -> dict: return {} diff --git a/src/fmu/dataio/scripts/create_case_metadata.py b/src/fmu/dataio/scripts/create_case_metadata.py index 0731ae9cf..f83fb43ac 100644 --- a/src/fmu/dataio/scripts/create_case_metadata.py +++ b/src/fmu/dataio/scripts/create_case_metadata.py @@ -8,6 +8,8 @@ pointed towards the produced global_variables, fmu-config should run before this script to make sure global_variables is updated.""" +from __future__ import annotations + import yaml import argparse import logging @@ -73,7 +75,7 @@ class WfCreateCaseMetadata(ErtScript): # name in fmu-dataio # pylint: disable=too-few-public-methods - def run(self, *args) -> None: + def run(self, *args: str) -> None: # pylint: disable=no-self-use """Parse arguments and call _create_case_metadata_main()""" parser = get_parser() @@ -81,18 +83,19 @@ def run(self, *args) -> None: create_case_metadata_main(workflow_args) -def create_case_metadata_main(args) -> None: +def create_case_metadata_main(args: argparse.Namespace) -> None: """Create the case metadata and register case on Sumo.""" logger.setLevel(level=args.verbosity) check_arguments(args) case_metadata_path = create_metadata(args) + assert case_metadata_path is not None register_on_sumo(args, case_metadata_path) logger.debug("create_case_metadata.py has finished.") -def create_metadata(args) -> str: +def create_metadata(args: argparse.Namespace) -> str | None: """Create the case metadata and print them to the disk""" _global_variables_path = Path(args.ert_config_path, args.global_variables_path) global_variables = _parse_yaml(_global_variables_path) @@ -112,7 +115,10 @@ def create_metadata(args) -> str: return case_metadata_path -def register_on_sumo(args, case_metadata_path) -> str: +def register_on_sumo( + args: argparse.Namespace, + case_metadata_path: str, +) -> str | None: """Register the case on Sumo by sending the case metadata""" env = args.sumo_env @@ -121,7 
+127,7 @@ def register_on_sumo(args, case_metadata_path) -> str: logger.info("Registering case on Sumo (%s)", env) else: logger.info("Sumo registration has been deactivated through arguments") - return + return None # lazy loading of Sumo dependencies from fmu.sumo.uploader import SumoConnection, CaseOnDisk @@ -141,7 +147,7 @@ def register_on_sumo(args, case_metadata_path) -> str: return sumo_id -def _parse_yaml(path): +def _parse_yaml(path: Path) -> dict: """Parse the global variables, return as dict""" with open(path, "r") as stream: @@ -150,7 +156,7 @@ def _parse_yaml(path): return data -def check_arguments(args): +def check_arguments(args: argparse.Namespace) -> None: """Do basic sanity checks of input""" logger.debug("Checking input arguments") @@ -189,7 +195,7 @@ def get_parser() -> argparse.ArgumentParser: @hook_implementation -def legacy_ertscript_workflow(config) -> None: +def legacy_ertscript_workflow(config: object) -> None: """Hook the WfCreateCaseMetadata class with documentation into ERT.""" workflow = config.add_workflow(WfCreateCaseMetadata, "WF_CREATE_CASE_METADATA") workflow.parser = get_parser
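A closing note on the stricter mypy.ini introduced at the top of this diff: disallow_untyped_defs = True turns any unannotated function into an error, which is why the vendored _oyaml.py gets a module-level "# type: ignore" and why tests, docs, examples, bin and conftest.py are excluded. A minimal illustration of what the flag enforces (hypothetical example, not repository code):

    def untyped(value):  # mypy: error: Function is missing a type annotation
        return value


    def typed(value: int) -> int:  # fully annotated, accepted
        return value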
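And on the "TODO(JB): Convert tresult to TypedDict or Dataclass" comment in _derive_timedata_legacy: declaring the structure that method builds would let the repeated cast(list, tresult["time"]) wrappers go away. A possible sketch, assuming the shape the current code constructs (a list of value/label entries); the type names are illustrative only and not part of fmu-dataio:

    from __future__ import annotations

    from typing import TypedDict


    class TimeEntry(TypedDict, total=False):
        value: str  # ISO-8601 datetime string, always set in practice
        label: str  # optional label, e.g. "monitor" or "base"


    class LegacyTimedata(TypedDict):
        time: list[TimeEntry]

With such a declaration, tresult["time"].append(xfield) would type-check directly.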