From e0c26928e6454be372bc5443acc85453cd49102e Mon Sep 17 00:00:00 2001 From: Eivind Jahren Date: Fri, 8 Mar 2024 14:25:11 +0100 Subject: [PATCH] Use a compiled regex for matching in summary looping over a fnmatch was simply too slow --- src/ert/config/_read_summary.py | 56 +++++++++++++++++--- tests/unit_tests/config/test_read_summary.py | 17 ++++++ 2 files changed, 65 insertions(+), 8 deletions(-) diff --git a/src/ert/config/_read_summary.py b/src/ert/config/_read_summary.py index ba2da3e494e..9acdfb58993 100644 --- a/src/ert/config/_read_summary.py +++ b/src/ert/config/_read_summary.py @@ -1,12 +1,22 @@ from __future__ import annotations +import fnmatch import os import os.path import re from datetime import datetime, timedelta from enum import Enum, auto -from fnmatch import fnmatch -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Tuple, + TypeVar, + Union, +) import numpy as np import numpy.typing as npt @@ -72,7 +82,7 @@ def from_keyword(cls, summary_keyword: str) -> _SummaryType: return cls.REGION if any( - re.match(pattern, summary_keyword) + re.fullmatch(pattern, summary_keyword) for pattern in [r"R.FT.*", r"R..FT.*", r"R.FR.*", r"R..FR.*", r"R.F"] ): return cls.INTER_REGION @@ -249,6 +259,38 @@ def _check_vals( return vals +def _fetch_keys_to_matcher(fetch_keys: Sequence[str]) -> Callable[[str], bool]: + """ + Transform the list of keys (with * used as repeated wildcard) into + a matcher. + + >>> match = _fetch_keys_to_matcher([""]) + >>> match("FOPR") + False + + >>> match = _fetch_keys_to_matcher(["*"]) + >>> match("FOPR"), match("FO*") + (True, True) + + + >>> match = _fetch_keys_to_matcher(["F*PR"]) + >>> match("WOPR"), match("FOPR"), match("FGPR"), match("SOIL") + (False, True, True, False) + + >>> match = _fetch_keys_to_matcher(["WGOR:*"]) + >>> match("FOPR"), match("WGOR:OP1"), match("WGOR:OP2"), match("WGOR") + (False, True, True, False) + + >>> match = _fetch_keys_to_matcher(["FOPR", "FGPR"]) + >>> match("FOPR"), match("FGPR"), match("WGOR:OP2"), match("WGOR") + (True, True, False, False) + """ + if not fetch_keys: + return lambda _: False + regex = re.compile("|".join(fnmatch.translate(key) for key in fetch_keys)) + return lambda s: regex.fullmatch(s) is not None + + def _read_spec( spec: str, fetch_keys: Sequence[str] ) -> Tuple[int, datetime, DateUnit, List[str], npt.NDArray[np.int64]]: @@ -349,6 +391,8 @@ def _read_spec( index_mapping: Dict[str, int] = {} date_index = None + should_load_key = _fetch_keys_to_matcher(fetch_keys) + def optional_get(arr: Optional[npt.NDArray[Any]], idx: int) -> Any: if arr is None: return None @@ -373,7 +417,7 @@ def optional_get(arr: Optional[npt.NDArray[Any]], idx: int) -> Any: lk = optional_get(numlz, i) key = make_summary_key(keyword, num, name, nx, ny, lgr_name, li, lj, lk) - if key is not None and _should_load_summary_key(key, fetch_keys): + if key is not None and should_load_key(key): if key in index_mapping: # only keep the index of the last occurrence of a key # this is done for backwards compatability @@ -470,7 +514,3 @@ def read_params() -> None: read_params() read_params() return np.array(values, dtype=np.float32).T, dates - - -def _should_load_summary_key(data_key: Any, user_set_keys: Sequence[str]) -> bool: - return any(fnmatch(data_key, key) for key in user_set_keys) diff --git a/tests/unit_tests/config/test_read_summary.py b/tests/unit_tests/config/test_read_summary.py index 6597fc58f1d..030b0b0e617 100644 --- a/tests/unit_tests/config/test_read_summary.py +++ b/tests/unit_tests/config/test_read_summary.py @@ -486,3 +486,20 @@ def test_that_ambiguous_case_restart_raises_an_informative_error( match="Ambiguous reference to unified summary", ): read_summary(str(tmp_path / "test"), ["*"]) + + +@given(summaries()) +def test_that_length_of_fetch_keys_does_not_reduce_performance( + tmp_path_factory, summary +): + """With a compiled regex this takes seconds to run, and with + a naive implementation it will take almost an hour. + """ + tmp_path = tmp_path_factory.mktemp("summary") + smspec, unsmry = summary + unsmry.to_file(tmp_path / "TEST.UNSMRY") + smspec.to_file(tmp_path / "TEST.SMSPEC") + fetch_keys = [str(i) for i in range(100000)] + (_, keys, time_map, _) = read_summary(str(tmp_path / "TEST"), fetch_keys) + assert all(k in fetch_keys for k in keys) + assert len(time_map) == len(unsmry.steps)