Skip to content

Commit

Permalink
Add some typing to everest
Browse files Browse the repository at this point in the history
  • Loading branch information
frode-aarstad committed Feb 5, 2025
1 parent 336c4a2 commit 386f1ed
Show file tree
Hide file tree
Showing 7 changed files with 221 additions and 33 deletions.
186 changes: 185 additions & 1 deletion .mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ ignore_missing_imports = True
ignore_missing_imports = True
disable_error_code = import-untyped

[mypy-everest.*]
[mypy-everest.api.*]
disable_error_code = dict-item,
no-untyped-def,
call-overload,
Expand All @@ -79,5 +79,189 @@ disable_error_code = dict-item,
return-value,
name-defined

[mypy-everest.templates.*]
disable_error_code = dict-item,
no-untyped-def,
call-overload,
union-attr,
no-untyped-call,
var-annotated,
index,
call-arg,
unused-ignore,
arg-type,
type-arg,
type-var,
assignment,
typeddict-item,
attr-defined,
comparison-overlap,
return-value,
name-defined

[mypy-everest.config.*]
disable_error_code = dict-item,
no-untyped-def,
call-overload,
union-attr,
no-untyped-call,
var-annotated,
index,
call-arg,
unused-ignore,
arg-type,
type-arg,
type-var,
assignment,
typeddict-item,
attr-defined,
comparison-overlap,
return-value,
name-defined

[mypy-everest.detached.*]
disable_error_code = dict-item,
no-untyped-def,
call-overload,
union-attr,
no-untyped-call,
var-annotated,
index,
call-arg,
unused-ignore,
arg-type,
type-arg,
type-var,
assignment,
typeddict-item,
attr-defined,
comparison-overlap,
return-value,
name-defined

[mypy-everest.jobs.*]
disable_error_code = dict-item,
no-untyped-def,
call-overload,
union-attr,
no-untyped-call,
var-annotated,
index,
call-arg,
unused-ignore,
arg-type,
type-arg,
type-var,
assignment,
typeddict-item,
attr-defined,
comparison-overlap,
return-value,
name-defined


[mypy-everest.plugins.*]
disable_error_code = dict-item,
no-untyped-def,
call-overload,
union-attr,
no-untyped-call,
var-annotated,
index,
call-arg,
unused-ignore,
arg-type,
type-arg,
type-var,
assignment,
typeddict-item,
attr-defined,
comparison-overlap,
return-value,
name-defined


[mypy-everest.bin.*]
disable_error_code = dict-item,
no-untyped-def,
call-overload,
union-attr,
no-untyped-call,
var-annotated,
index,
call-arg,
unused-ignore,
arg-type,
type-arg,
type-var,
assignment,
typeddict-item,
attr-defined,
comparison-overlap,
return-value,
name-defined


[mypy-everest.everest_storage.*]
disable_error_code = dict-item,
no-untyped-def,
call-overload,
union-attr,
no-untyped-call,
var-annotated,
index,
call-arg,
unused-ignore,
arg-type,
type-arg,
type-var,
assignment,
typeddict-item,
attr-defined,
comparison-overlap,
return-value,
name-defined

[mypy-everest.optimizer.*]
disable_error_code = dict-item,
no-untyped-def,
call-overload,
union-attr,
no-untyped-call,
var-annotated,
index,
call-arg,
unused-ignore,
arg-type,
type-arg,
type-var,
assignment,
typeddict-item,
attr-defined,
comparison-overlap,
return-value,
name-defined

[mypy-everest.simulator.*]
disable_error_code = dict-item,
no-untyped-def,
call-overload,
union-attr,
no-untyped-call,
var-annotated,
index,
call-arg,
unused-ignore,
arg-type,
type-arg,
type-var,
assignment,
typeddict-item,
attr-defined,
comparison-overlap,
return-value,
name-defined


[mypy-tests.*]
disable_error_code = no-untyped-def, no-untyped-call, typeddict-item, assignment
21 changes: 13 additions & 8 deletions src/everest/config_file_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ def load_yaml(file_name: str) -> dict[str, Any] | None:
return None


def _get_definitions(configuration, configpath):
def _get_definitions(
configuration: dict[str, Any] | None, configpath: str
) -> dict[str, Any]:
defs = {}
if configuration:
if "definitions" not in configuration:
Expand Down Expand Up @@ -74,7 +76,11 @@ def _get_definitions(configuration, configpath):
return defs


def _os():
class Os:
pass


def _os() -> Os:
"""Return an object whose properties are the users environment variables.
For example, calling os.USER returns the username, os.HOSTNAME returns the
Expand All @@ -83,23 +89,22 @@ def _os():
"""

class Os:
pass

x = Os()
x.__dict__.update(os.environ)
return x


def _render_definitions(definitions, jinja_env):
def _render_definitions(
definitions: dict[str, Any], jinja_env: jinja2.Environment
) -> None:
# pylint: disable=unnecessary-lambda-assignment
render = lambda s, d: jinja_env.from_string(s).render(**d)
for key in definitions:
for key in definitions: # noqa: PLC0206
if not isinstance(definitions[key], str):
continue

for _idx in range(len(definitions) + 1):
new_val = render(definitions[key], definitions)
new_val = render(definitions[key], definitions) # type: ignore
if definitions[key] != new_val:
definitions[key] = new_val
else:
Expand Down
2 changes: 1 addition & 1 deletion src/everest/everest_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ropt.enums import EventType
from ropt.plan import BasicOptimizer, Event
from ropt.results import FunctionResults, GradientResults, convert_to_maximize
from ropt.transforms import OptModelTransforms
from ropt.transforms import OptModelTransforms # type: ignore

from everest.config import EverestConfig
from everest.strings import EVEREST
Expand Down
2 changes: 1 addition & 1 deletion src/everest/optimizer/everest2ropt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ropt.config.enopt import EnOptConfig, EnOptContext
from ropt.enums import ConstraintType, PerturbationType, VariableType
from ropt.transforms import OptModelTransforms
from ropt.transforms import OptModelTransforms # type: ignore

from everest.config import (
EverestConfig,
Expand Down
6 changes: 3 additions & 3 deletions src/everest/optimizer/opt_model_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

import numpy as np
from numpy.typing import NDArray
from ropt.transforms import OptModelTransforms
from ropt.transforms.base import VariableTransform
from ropt.transforms import OptModelTransforms # type: ignore
from ropt.transforms.base import VariableTransform # type: ignore

from everest.config import ControlConfig
from everest.config.utils import FlattenedControls


class ControlScaler(VariableTransform):
class ControlScaler(VariableTransform): # type: ignore
def __init__(
self,
lower_bounds: Sequence[float],
Expand Down
30 changes: 14 additions & 16 deletions src/everest/util/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime
import logging
import os
from datetime import UTC, datetime

from ropt.version import version as ropt_version

Expand All @@ -15,28 +15,28 @@
from opm.io.ecl_state import EclipseState
from opm.io.schedule import Schedule

def has_opm():
def has_opm() -> bool:
return True

except ImportError:

def has_opm():
def has_opm() -> bool:
return False


def version_info():
def version_info() -> str:
return f"everest:'{ert_version}'\nropt:'{ropt_version}'\nert:'{ert_version}'"


def date2str(date):
return datetime.datetime.strftime(date, DATE_FORMAT)
def date2str(date: datetime) -> str:
return datetime.strftime(date, DATE_FORMAT)


def str2date(date_str):
return datetime.datetime.strptime(date_str, DATE_FORMAT)
def str2date(date_str: str) -> datetime:
return datetime.strptime(date_str, DATE_FORMAT)


def makedirs_if_needed(path, roll_if_exists=False):
def makedirs_if_needed(path: str, roll_if_exists: bool = False) -> None:
if os.path.isdir(path):
if not roll_if_exists:
return
Expand All @@ -56,16 +56,14 @@ def warn_user_that_runpath_is_nonempty() -> None:
logging.getLogger(EVEREST).warning("Everest is running in an existing runpath")


def _roll_dir(old_name):
def _roll_dir(old_name: str) -> None:
old_name = os.path.realpath(old_name)
new_name = old_name + datetime.datetime.now(datetime.UTC).strftime(
"__%Y-%m-%d_%H.%M.%S.%f"
)
new_name = old_name + datetime.now(UTC).strftime("__%Y-%m-%d_%H.%M.%S.%f")
os.rename(old_name, new_name)
logging.getLogger(EVEREST).info(f"renamed {old_name} to {new_name}")


def load_deck(fname):
def load_deck(fname: str): # type: ignore
"""Take a .DATA file and return an opm.io.Deck."""
if not os.path.exists(fname):
raise OSError(f'No such data file "{fname}".')
Expand Down Expand Up @@ -97,7 +95,7 @@ def load_deck(fname):
return opm.io.Parser().parse(fname, parse_context)


def read_wellnames(fname):
def read_wellnames(fname: str) -> list[str]:
"""Take a .DATA file and return the list of well
names at time the first timestep from deck."""
deck = load_deck(fname)
Expand All @@ -106,7 +104,7 @@ def read_wellnames(fname):
return [str(well.name) for well in schedule.get_wells(0)]


def read_groupnames(fname):
def read_groupnames(fname: str) -> list[str]:
"""Take a .DATA file and return the list of group
names at the first timestep from deck."""
deck = load_deck(fname)
Expand Down
7 changes: 4 additions & 3 deletions src/everest/util/forward_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import TypeVar
from collections.abc import Sequence
from typing import Any, TypeVar

from pydantic import BaseModel, ValidationError

Expand All @@ -9,14 +10,14 @@
T = TypeVar("T", bound=BaseModel)


def collect_forward_model_schemas():
def collect_forward_model_schemas() -> dict[str, Any] | None:
schemas = pm.hook.get_forward_models_schemas()
if schemas:
return schemas.pop()
return {}


def lint_forward_model_job(job: str, args) -> list[str]:
def lint_forward_model_job(job: str, args: Sequence[str]) -> list[str]:
return pm.hook.lint_forward_model(job=job, args=args)


Expand Down

0 comments on commit 386f1ed

Please sign in to comment.